Skip to main content

atomic_progress/
progress.rs

1//! Core primitives for tracking progress state.
2//!
3//! This module defines the [`Progress`] struct, which acts as the central handle for
4//! updates. It is designed around a "Hot/Cold" split to maximize performance in
5//! multi-threaded environments:
6//!
//! * **Hot Data:** Position, Total, and Finished state are stored in `Atomic` primitives.
//!   This allows high-frequency updates (e.g., in tight loops) without lock contention.
//! * **Cold Data:** Metadata like the name, error state and stop time is guarded by an
//!   [`RwLock`](parking_lot::RwLock); the current item sits behind a separate lock of its
//!   own. These are accessed less frequently, typically only by the rendering thread or
//!   when significant state changes occur.
12//!
13//! # Snapshots
14//!
//! To render progress safely, use [`Progress::snapshot`] to obtain a [`ProgressSnapshot`].
//! This is an owned, immutable copy of the progress state that can be read without any
//! locking; derived metrics such as ETA and throughput are computed from it on demand.
18
19use std::{
20    sync::{
21        Arc,
22        atomic::{AtomicBool, AtomicU64, Ordering},
23    },
24    time::Duration,
25};
26
27use compact_str::CompactString;
28use parking_lot::RwLock;
29use web_time::Instant;
30
31/// A thread-safe, cloneable handle to a progress indicator.
32///
33/// `Progress` separates "hot" data (position, total, finished status) which are stored in
34/// atomics for high-performance updates, from "cold" data (names, errors, timing) which are
35/// guarded by an [`RwLock`].
36///
37/// Cloning a `Progress` is cheap (Arc bump) and points to the same underlying state.
38#[derive(Clone)]
39pub struct Progress {
40    /// The type of progress indicator (Bar vs Spinner). Immutable after creation.
41    pub(crate) kind: ProgressType,
42
43    /// The instant the progress tracker was created/started.
44    pub(crate) start: Option<Instant>,
45
46    /// Infrequently accessed metadata (name, error state, stop time).
47    pub(crate) cold: Arc<RwLock<Cold>>,
48
49    /// The current "item" being processed (e.g., filename).
50    pub(crate) item: Arc<RwLock<CompactString>>,
51
52    // Atomic fields for wait-free updates on the hot path.
53    pub(crate) position: Arc<AtomicU64>,
54    pub(crate) total: Arc<AtomicU64>,
55    pub(crate) finished: Arc<AtomicBool>,
56}
57
58/// "Cold" storage for metadata that changes infrequently.
59pub struct Cold {
60    pub(crate) name: CompactString,
61    pub(crate) stopped: Option<Instant>,
62    pub(crate) error: Option<CompactString>,
63}
64
/// Rendering hint describing how a progress indicator should be displayed.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
pub enum ProgressType {
    /// Indeterminate indicator for work whose total is unknown. This is the default.
    #[default]
    Spinner,
    /// Determinate bar for work with a known total.
    Bar,
}
81
82/// Slightly dangerous, but convenient to derive inside larger structs.
83/// Be sure to only use this for debugging purposes, as it may not reflect the most up-to-date state if other threads are modifying it concurrently.
84/// This is not suitable for production code or any logic that relies on consistent state.
85impl std::fmt::Debug for Progress {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        let cold = self.cold.read();
88        f.debug_struct("Progress")
89            .field("kind", &self.kind)
90            .field("start", &self.start)
91            .field("name", &cold.name)
92            .field("item", &self.item.read())
93            .field("position", &self.position.load(Ordering::Relaxed))
94            .field("total", &self.total.load(Ordering::Relaxed))
95            .field("finished", &self.finished.load(Ordering::Relaxed))
96            .field("error", &cold.error)
97            .finish()
98    }
99}
100
101impl Progress {
102    /// Creates a new `Progress` instance.
103    ///
104    /// # Parameters
105    ///
106    /// * `kind`: The type of indicator.
107    /// * `name`: A label for the task.
108    /// * `total`: The total expected count (use 0 for spinners).
109    pub fn new(kind: ProgressType, name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
110        Self {
111            kind,
112            start: None,
113            cold: Arc::new(RwLock::new(Cold {
114                name: name.into(),
115                stopped: None,
116                error: None,
117            })),
118            item: Arc::new(RwLock::new(CompactString::default())),
119            position: Arc::new(AtomicU64::new(0)),
120            total: Arc::new(AtomicU64::new(total.into())),
121            finished: Arc::new(AtomicBool::new(false)),
122        }
123    }
124
125    /// Creates a new generic progress bar with a known total.
126    #[must_use]
127    pub fn new_pb(name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
128        Self::new(ProgressType::Bar, name, total)
129    }
130
131    /// Creates a new spinner (indeterminate progress).
132    #[must_use]
133    pub fn new_spinner(name: impl Into<CompactString>) -> Self {
134        Self::new(ProgressType::Spinner, name, 0u64)
135    }
136
137    // ========================================================================
138    // Metadata Accessors
139    // ========================================================================
140
141    /// Gets the current name/label of the progress task.
142    #[must_use]
143    pub fn get_name(&self) -> CompactString {
144        self.cold.read().name.clone()
145    }
146
147    /// Updates the name/label of the progress task.
148    pub fn set_name(&self, name: impl Into<CompactString>) {
149        self.cold.write().name = name.into();
150    }
151
152    /// Gets the current item description (e.g., currently processing file).
153    #[must_use]
154    pub fn get_item(&self) -> CompactString {
155        self.item.read().clone()
156    }
157
158    /// Updates the current item description.
159    pub fn set_item(&self, item: impl Into<CompactString>) {
160        *self.item.write() = item.into();
161    }
162
163    /// Returns the error message, if one occurred.
164    #[must_use]
165    pub fn get_error(&self) -> Option<CompactString> {
166        self.cold.read().error.clone()
167    }
168
169    /// Sets (or clears) an error message for this task.
170    pub fn set_error(&self, error: Option<impl Into<CompactString>>) {
171        let error = error.map(Into::into);
172        self.cold.write().error = error;
173    }
174
175    // ========================================================================
176    // State & Metrics (Hot Path)
177    // ========================================================================
178
179    /// Increments the progress position by the specified amount.
180    ///
181    /// This uses `Ordering::Relaxed` for maximum performance.
182    pub fn inc(&self, amount: impl Into<u64>) {
183        self.position.fetch_add(amount.into(), Ordering::Relaxed);
184    }
185
186    /// Increments the progress position by 1.
187    pub fn bump(&self) {
188        self.inc(1u64);
189    }
190
191    /// Gets the current position.
192    #[must_use]
193    pub fn get_pos(&self) -> u64 {
194        self.position.load(Ordering::Relaxed)
195    }
196
197    /// Sets the absolute position.
198    pub fn set_pos(&self, pos: u64) {
199        self.position.store(pos, Ordering::Relaxed);
200    }
201
202    /// Gets the total target count.
203    #[must_use]
204    pub fn get_total(&self) -> u64 {
205        self.total.load(Ordering::Relaxed)
206    }
207
208    /// Updates the total target count.
209    pub fn set_total(&self, total: u64) {
210        self.total.store(total, Ordering::Relaxed);
211    }
212
213    /// Checks if the task is marked as finished.
214    #[must_use]
215    pub fn is_finished(&self) -> bool {
216        // Acquire ensures we see any memory writes that happened before the finish flag was set.
217        self.finished.load(Ordering::Acquire)
218    }
219
220    /// Manually sets the finished state.
221    ///
222    /// Prefer using [`finish`](Self::finish), [`finish_with_item`](Self::finish_with_item),
223    /// or [`finish_with_error`](Self::finish_with_error) to ensure timestamps are recorded.
224    pub fn set_finished(&self, finished: bool) {
225        self.finished.store(finished, Ordering::Release);
226    }
227
228    // ========================================================================
229    // Timing & Calculations
230    // ========================================================================
231
232    /// Calculates the duration elapsed since creation.
233    ///
234    /// If the task is finished, this returns the duration between start and finish.
235    /// If never started (no start time recorded), returns `None`.
236    #[must_use]
237    pub fn get_elapsed(&self) -> Option<Duration> {
238        let start = self.start?;
239        let cold = self.cold.read();
240
241        Some(
242            cold.stopped
243                .map_or_else(|| start.elapsed(), |stopped| stopped.duration_since(start)),
244        )
245    }
246
247    /// Returns the current completion percentage (0.0 to 100.0).
248    ///
249    /// Returns `0.0` if `total` is zero.
250    #[allow(clippy::cast_precision_loss)]
251    #[must_use]
252    pub fn get_percent(&self) -> f64 {
253        let pos = self.get_pos() as f64;
254        let total = self.get_total() as f64;
255
256        if total == 0.0 {
257            0.0
258        } else {
259            (pos / total) * 100.0
260        }
261    }
262
263    // ========================================================================
264    // Lifecycle Management
265    // ========================================================================
266
267    /// Marks the task as finished and records the stop time.
268    pub fn finish(&self) {
269        if self.start.is_some() {
270            self.cold.write().stopped.replace(Instant::now());
271        }
272        self.set_finished(true);
273    }
274
275    /// Sets the current item and marks the task as finished.
276    pub fn finish_with_item(&self, item: impl Into<CompactString>) {
277        self.set_item(item);
278        self.finish(); // Calls set_finished(true) internally
279    }
280
281    /// Sets an error message and marks the task as finished.
282    pub fn finish_with_error(&self, error: impl Into<CompactString>) {
283        self.set_error(Some(error));
284        self.finish();
285    }
286
287    // ========================================================================
288    // Advanced / Internal
289    // ========================================================================
290
291    /// Returns a shared reference to the atomic position counter.
292    ///
293    /// Useful for sharing this specific counter with other systems.
294    #[must_use]
295    pub fn atomic_pos(&self) -> Arc<AtomicU64> {
296        self.position.clone()
297    }
298
299    /// Returns a shared reference to the atomic total counter.
300    #[must_use]
301    pub fn atomic_total(&self) -> Arc<AtomicU64> {
302        self.total.clone()
303    }
304
305    /// Creates a consistent snapshot of the current state.
306    ///
307    /// This involves acquiring a read lock on the "cold" data.
308    #[must_use]
309    pub fn snapshot(&self) -> ProgressSnapshot {
310        self.into()
311    }
312}
313
314/// A plain-data snapshot of a [`Progress`] state at a specific point in time.
315///
316/// This is typically used for rendering, as it holds owned data and requires no locking
317/// to access.
318#[derive(Clone, Debug, Default, Eq, PartialEq)]
319#[cfg_attr(
320    feature = "rkyv",
321    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
322)]
323#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
324#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
325pub struct ProgressSnapshot {
326    /// The type of progress indicator.
327    pub kind: ProgressType,
328
329    /// The name/label of the progress task.
330    pub name: CompactString,
331    /// The current item description.
332    pub item: CompactString,
333
334    /// The elapsed duration.
335    pub elapsed: Option<Duration>,
336
337    /// The current position.
338    pub position: u64,
339    /// The total target count.
340    pub total: u64,
341
342    /// Whether the task is finished.
343    pub finished: bool,
344
345    /// The associated error message, if any.
346    pub error: Option<CompactString>,
347}
348
349impl From<&Progress> for ProgressSnapshot {
350    fn from(progress: &Progress) -> Self {
351        // Lock cold data once
352        let cold = progress.cold.read();
353        let name = cold.name.clone();
354        let error = cold.error.clone();
355        drop(cold);
356
357        Self {
358            kind: progress.kind,
359            name,
360            item: progress.item.read().clone(),
361            elapsed: progress.get_elapsed(),
362            position: progress.position.load(Ordering::Relaxed),
363            total: progress.total.load(Ordering::Relaxed),
364            finished: progress.finished.load(Ordering::Relaxed),
365            error,
366        }
367    }
368}
369
370impl ProgressSnapshot {
371    /// Returns the type of progress indicator.
372    #[must_use]
373    pub const fn kind(&self) -> ProgressType {
374        self.kind
375    }
376
377    /// Returns the name/label of the progress task.
378    #[must_use]
379    pub fn name(&self) -> &str {
380        &self.name
381    }
382    /// Returns the current item description.
383    #[must_use]
384    pub fn item(&self) -> &str {
385        &self.item
386    }
387
388    /// Returns the elapsed duration.
389    #[must_use]
390    pub const fn elapsed(&self) -> Option<Duration> {
391        self.elapsed
392    }
393
394    /// Returns the current position.
395    #[must_use]
396    pub const fn position(&self) -> u64 {
397        self.position
398    }
399    /// Returns the total target count.
400    #[must_use]
401    pub const fn total(&self) -> u64 {
402        self.total
403    }
404
405    /// Returns whether the task is finished.
406    #[must_use]
407    pub const fn finished(&self) -> bool {
408        self.finished
409    }
410
411    /// Returns the error message, if any.
412    #[must_use]
413    pub fn error(&self) -> Option<&str> {
414        self.error.as_deref()
415    }
416
417    /// Estimates the time remaining (ETA) based on average speed since start.
418    ///
419    /// Returns `None` if:
420    /// * No progress has been made.
421    /// * Total is zero.
422    /// * Process is finished.
423    /// * Elapsed time is effectively zero.
424    #[allow(clippy::cast_precision_loss)]
425    #[must_use]
426    pub fn eta(&self) -> Option<Duration> {
427        if self.position == 0 || self.total == 0 || self.finished {
428            return None;
429        }
430
431        let elapsed = self.elapsed?;
432        let secs = elapsed.as_secs_f64();
433
434        // Avoid division by zero or extremely small intervals
435        if secs <= 1e-6 {
436            return None;
437        }
438
439        let rate = self.position as f64 / secs;
440        if rate <= 0.0 {
441            return None;
442        }
443
444        let remaining_items = self.total.saturating_sub(self.position);
445        let remaining_secs = remaining_items as f64 / rate;
446
447        Some(Duration::from_secs_f64(remaining_secs))
448    }
449
450    /// Calculates the average throughput (items per second) over the entire lifetime.
451    #[allow(clippy::cast_precision_loss)]
452    #[must_use]
453    pub fn throughput(&self) -> f64 {
454        if let Some(elapsed) = self.elapsed {
455            let secs = elapsed.as_secs_f64();
456            if secs > 0.0 {
457                return self.position as f64 / secs;
458            }
459        }
460        0.0
461    }
462
463    /// Calculates the instantaneous throughput relative to a previous snapshot.
464    ///
465    /// This is useful for calculating "current speed" (e.g., in the last second).
466    #[allow(clippy::cast_precision_loss)]
467    #[must_use]
468    pub fn throughput_since(&self, prev: &Self) -> f64 {
469        let pos_diff = self.position.saturating_sub(prev.position) as f64;
470
471        let time_diff = match (self.elapsed, prev.elapsed) {
472            (Some(curr), Some(old)) => curr.as_secs_f64() - old.as_secs_f64(),
473            _ => 0.0,
474        };
475
476        if time_diff > 0.0 {
477            pos_diff / time_diff
478        } else {
479            0.0
480        }
481    }
482}
483
#[cfg(test)]
mod tests {
    use std::thread;

    use super::Progress;

    /// New -> inc -> finish: position, percentage and the finished flag move as expected,
    /// and a handle created without a start time reports no elapsed time.
    #[test]
    #[allow(clippy::float_cmp)] // 0/100 and 50/100 are exactly representable
    fn test_basic_lifecycle() {
        let job = Progress::new_pb("test_job", 100u64);

        assert_eq!(job.get_pos(), 0);
        assert!(!job.is_finished());
        assert_eq!(job.get_percent(), 0.0);

        job.inc(50u64);
        assert_eq!(job.get_pos(), 50);
        assert_eq!(job.get_percent(), 50.0);

        job.finish();
        assert!(job.is_finished());

        // `new_pb` records no start time, so there is nothing to measure.
        assert!(job.get_elapsed().is_none());
    }

    /// Ten threads each add 1 a hundred times through cloned handles; no increment may be lost.
    #[test]
    fn test_concurrency_atomics() {
        let shared = Progress::new_spinner("concurrent_job");

        // The scope joins every worker (and propagates any panic) before returning.
        thread::scope(|scope| {
            for _ in 0..10 {
                let worker = shared.clone();
                scope.spawn(move || (0..100).for_each(|_| worker.inc(1u64)));
            }
        });

        assert_eq!(shared.get_pos(), 1000, "Atomic updates should be lossless");
    }

    /// Cold metadata (name, item, error) written through the setters shows up in a snapshot.
    #[test]
    fn test_snapshot_metadata() {
        let job = Progress::new_pb("initial_name", 100u64);

        job.set_name("updated_name");
        job.set_item("file_a.txt");
        job.set_error(Some("disk_full"));

        let snap = job.snapshot();
        assert_eq!(snap.name, "updated_name");
        assert_eq!(snap.item, "file_a.txt");
        assert_eq!(snap.error, Some("disk_full".into()));
    }

    /// Degenerate inputs (no elapsed time, no progress, zero total) must not yield NaN/inf.
    #[allow(clippy::float_cmp)]
    #[test]
    fn test_math_safety() {
        let fresh = Progress::new_pb("math_test", 100u64).snapshot();
        assert_eq!(fresh.throughput(), 0.0);
        assert!(fresh.eta().is_none());

        // Time cannot be mocked here without sleeping, so beyond this point only the
        // zero-total percentage path is exercised.
        let empty = Progress::new_pb("zero_total", 0u64);
        assert_eq!(empty.get_percent(), 0.0);
    }
}