atomic_progress/
progress.rs

1//! Core primitives for tracking progress state.
2//!
3//! This module defines the [`Progress`] struct, which acts as the central handle for
4//! updates. It is designed around a "Hot/Cold" split to maximize performance in
5//! multi-threaded environments:
6//!
7//! * **Hot Data:** Position, Total, and Finished state are stored in `Atomic` primitives.
8//!   This allows high-frequency updates (e.g., in tight loops) without locking contention.
//! * **Cold Data:** Metadata such as the name, stop time, and error state is guarded by an
//!   [`RwLock`](parking_lot::RwLock); the current item has a separate `RwLock` of its own.
//!   These are accessed less frequently, typically only by the rendering thread or when
//!   significant state changes occur.
12//!
13//! # Snapshots
14//!
//! To render progress safely, use [`Progress::snapshot`] to obtain a [`ProgressSnapshot`].
//! This provides an owned, immutable copy of the progress state that can be read without
//! any locking. Derived metrics such as ETA and throughput are computed on demand from the
//! snapshot (see [`ProgressSnapshot::eta`] and [`ProgressSnapshot::throughput`]).
18
19use std::{
20    sync::{
21        Arc,
22        atomic::{AtomicBool, AtomicU64, Ordering},
23    },
24    time::Duration,
25};
26
27use compact_str::CompactString;
28use parking_lot::RwLock;
29use web_time::Instant;
30
31/// A thread-safe, cloneable handle to a progress indicator.
32///
33/// `Progress` separates "hot" data (position, total, finished status) which are stored in
34/// atomics for high-performance updates, from "cold" data (names, errors, timing) which are
35/// guarded by an [`RwLock`].
36///
37/// Cloning a `Progress` is cheap (Arc bump) and points to the same underlying state.
38#[derive(Clone)]
39pub struct Progress {
40    /// The type of progress indicator (Bar vs Spinner). Immutable after creation.
41    pub(crate) kind: ProgressType,
42
43    /// The instant the progress tracker was created/started.
44    pub(crate) start: Option<Instant>,
45
46    /// Infrequently accessed metadata (name, error state, stop time).
47    pub(crate) cold: Arc<RwLock<Cold>>,
48
49    /// The current "item" being processed (e.g., filename).
50    pub(crate) item: Arc<RwLock<CompactString>>,
51
52    // Atomic fields for wait-free updates on the hot path.
53    pub(crate) position: Arc<AtomicU64>,
54    pub(crate) total: Arc<AtomicU64>,
55    pub(crate) finished: Arc<AtomicBool>,
56}
57
58/// "Cold" storage for metadata that changes infrequently.
59pub struct Cold {
60    pub(crate) name: CompactString,
61    pub(crate) stopped: Option<Instant>,
62    pub(crate) error: Option<CompactString>,
63}
64
/// Rendering hint describing how a progress indicator should be drawn.
///
/// The enum is `#[repr(u8)]` with `Spinner` declared first, so `Spinner` has discriminant `0`
/// and `Bar` has discriminant `1`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    rkyv(derive(Debug, Eq, PartialEq))
)]
pub enum ProgressType {
    /// Indeterminate indicator for work whose total is unknown. This is the default kind.
    #[default]
    Spinner,
    /// Determinate bar for work whose total is known up front.
    Bar,
}
81
82impl Progress {
83    /// Creates a new `Progress` instance.
84    ///
85    /// # Parameters
86    ///
87    /// * `kind`: The type of indicator.
88    /// * `name`: A label for the task.
89    /// * `total`: The total expected count (use 0 for spinners).
90    pub fn new(kind: ProgressType, name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
91        Self {
92            kind,
93            start: None,
94            cold: Arc::new(RwLock::new(Cold {
95                name: name.into(),
96                stopped: None,
97                error: None,
98            })),
99            item: Arc::new(RwLock::new(CompactString::default())),
100            position: Arc::new(AtomicU64::new(0)),
101            total: Arc::new(AtomicU64::new(total.into())),
102            finished: Arc::new(AtomicBool::new(false)),
103        }
104    }
105
106    /// Creates a new generic progress bar with a known total.
107    #[must_use]
108    pub fn new_pb(name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
109        Self::new(ProgressType::Bar, name, total)
110    }
111
112    /// Creates a new spinner (indeterminate progress).
113    #[must_use]
114    pub fn new_spinner(name: impl Into<CompactString>) -> Self {
115        Self::new(ProgressType::Spinner, name, 0u64)
116    }
117
118    // ========================================================================
119    // Metadata Accessors
120    // ========================================================================
121
122    /// Gets the current name/label of the progress task.
123    #[must_use]
124    pub fn get_name(&self) -> CompactString {
125        self.cold.read().name.clone()
126    }
127
128    /// Updates the name/label of the progress task.
129    pub fn set_name(&self, name: impl Into<CompactString>) {
130        self.cold.write().name = name.into();
131    }
132
133    /// Gets the current item description (e.g., currently processing file).
134    #[must_use]
135    pub fn get_item(&self) -> CompactString {
136        self.item.read().clone()
137    }
138
139    /// Updates the current item description.
140    pub fn set_item(&self, item: impl Into<CompactString>) {
141        *self.item.write() = item.into();
142    }
143
144    /// Returns the error message, if one occurred.
145    #[must_use]
146    pub fn get_error(&self) -> Option<CompactString> {
147        self.cold.read().error.clone()
148    }
149
150    /// Sets (or clears) an error message for this task.
151    pub fn set_error(&self, error: Option<impl Into<CompactString>>) {
152        let error = error.map(Into::into);
153        self.cold.write().error = error;
154    }
155
156    // ========================================================================
157    // State & Metrics (Hot Path)
158    // ========================================================================
159
160    /// Increments the progress position by the specified amount.
161    ///
162    /// This uses `Ordering::Relaxed` for maximum performance.
163    pub fn inc(&self, amount: impl Into<u64>) {
164        self.position.fetch_add(amount.into(), Ordering::Relaxed);
165    }
166
167    /// Increments the progress position by 1.
168    pub fn bump(&self) {
169        self.inc(1u64);
170    }
171
172    /// Gets the current position.
173    #[must_use]
174    pub fn get_pos(&self) -> u64 {
175        self.position.load(Ordering::Relaxed)
176    }
177
178    /// Sets the absolute position.
179    pub fn set_pos(&self, pos: u64) {
180        self.position.store(pos, Ordering::Relaxed);
181    }
182
183    /// Gets the total target count.
184    #[must_use]
185    pub fn get_total(&self) -> u64 {
186        self.total.load(Ordering::Relaxed)
187    }
188
189    /// Updates the total target count.
190    pub fn set_total(&self, total: u64) {
191        self.total.store(total, Ordering::Relaxed);
192    }
193
194    /// Checks if the task is marked as finished.
195    #[must_use]
196    pub fn is_finished(&self) -> bool {
197        // Acquire ensures we see any memory writes that happened before the finish flag was set.
198        self.finished.load(Ordering::Acquire)
199    }
200
201    /// Manually sets the finished state.
202    ///
203    /// Prefer using [`finish`](Self::finish), [`finish_with_item`](Self::finish_with_item),
204    /// or [`finish_with_error`](Self::finish_with_error) to ensure timestamps are recorded.
205    pub fn set_finished(&self, finished: bool) {
206        self.finished.store(finished, Ordering::Release);
207    }
208
209    // ========================================================================
210    // Timing & Calculations
211    // ========================================================================
212
213    /// Calculates the duration elapsed since creation.
214    ///
215    /// If the task is finished, this returns the duration between start and finish.
216    /// If never started (no start time recorded), returns `None`.
217    #[must_use]
218    pub fn get_elapsed(&self) -> Option<Duration> {
219        let start = self.start?;
220        let cold = self.cold.read();
221
222        Some(
223            cold.stopped
224                .map_or_else(|| start.elapsed(), |stopped| stopped.duration_since(start)),
225        )
226    }
227
228    /// Returns the current completion percentage (0.0 to 100.0).
229    ///
230    /// Returns `0.0` if `total` is zero.
231    #[allow(clippy::cast_precision_loss)]
232    #[must_use]
233    pub fn get_percent(&self) -> f64 {
234        let pos = self.get_pos() as f64;
235        let total = self.get_total() as f64;
236
237        if total == 0.0 {
238            0.0
239        } else {
240            (pos / total) * 100.0
241        }
242    }
243
244    // ========================================================================
245    // Lifecycle Management
246    // ========================================================================
247
248    /// Marks the task as finished and records the stop time.
249    pub fn finish(&self) {
250        if self.start.is_some() {
251            self.cold.write().stopped.replace(Instant::now());
252        }
253        self.set_finished(true);
254    }
255
256    /// Sets the current item and marks the task as finished.
257    pub fn finish_with_item(&self, item: impl Into<CompactString>) {
258        self.set_item(item);
259        self.finish(); // Calls set_finished(true) internally
260    }
261
262    /// Sets an error message and marks the task as finished.
263    pub fn finish_with_error(&self, error: impl Into<CompactString>) {
264        self.set_error(Some(error));
265        self.finish();
266    }
267
268    // ========================================================================
269    // Advanced / Internal
270    // ========================================================================
271
272    /// Returns a shared reference to the atomic position counter.
273    ///
274    /// Useful for sharing this specific counter with other systems.
275    #[must_use]
276    pub fn atomic_pos(&self) -> Arc<AtomicU64> {
277        self.position.clone()
278    }
279
280    /// Returns a shared reference to the atomic total counter.
281    #[must_use]
282    pub fn atomic_total(&self) -> Arc<AtomicU64> {
283        self.total.clone()
284    }
285
286    /// Creates a consistent snapshot of the current state.
287    ///
288    /// This involves acquiring a read lock on the "cold" data.
289    #[must_use]
290    pub fn snapshot(&self) -> ProgressSnapshot {
291        self.into()
292    }
293}
294
295/// A plain-data snapshot of a [`Progress`] state at a specific point in time.
296///
297/// This is typically used for rendering, as it holds owned data and requires no locking
298/// to access.
299#[derive(Clone, Debug, Default, Eq, PartialEq)]
300#[cfg_attr(
301    feature = "rkyv",
302    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
303)]
304#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
305#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
306pub struct ProgressSnapshot {
307    kind: ProgressType,
308
309    name: CompactString,
310    item: CompactString,
311
312    elapsed: Option<Duration>,
313
314    position: u64,
315    total: u64,
316
317    finished: bool,
318
319    error: Option<CompactString>,
320}
321
322impl From<&Progress> for ProgressSnapshot {
323    fn from(progress: &Progress) -> Self {
324        // Lock cold data once
325        let cold = progress.cold.read();
326        let name = cold.name.clone();
327        let error = cold.error.clone();
328        drop(cold);
329
330        Self {
331            kind: progress.kind,
332            name,
333            item: progress.item.read().clone(),
334            elapsed: progress.get_elapsed(),
335            position: progress.position.load(Ordering::Relaxed),
336            total: progress.total.load(Ordering::Relaxed),
337            finished: progress.finished.load(Ordering::Relaxed),
338            error,
339        }
340    }
341}
342
343impl ProgressSnapshot {
344    /// Returns the type of progress indicator.
345    #[must_use]
346    pub const fn kind(&self) -> ProgressType {
347        self.kind
348    }
349
350    /// Returns the name/label of the progress task.
351    #[must_use]
352    pub fn name(&self) -> &str {
353        &self.name
354    }
355    /// Returns the current item description.
356    #[must_use]
357    pub fn item(&self) -> &str {
358        &self.item
359    }
360
361    /// Returns the elapsed duration.
362    #[must_use]
363    pub const fn elapsed(&self) -> Option<Duration> {
364        self.elapsed
365    }
366
367    /// Returns the current position.
368    #[must_use]
369    pub const fn position(&self) -> u64 {
370        self.position
371    }
372    /// Returns the total target count.
373    #[must_use]
374    pub const fn total(&self) -> u64 {
375        self.total
376    }
377
378    /// Returns whether the task is finished.
379    #[must_use]
380    pub const fn finished(&self) -> bool {
381        self.finished
382    }
383
384    /// Returns the error message, if any.
385    #[must_use]
386    pub fn error(&self) -> Option<&str> {
387        self.error.as_deref()
388    }
389
390    /// Estimates the time remaining (ETA) based on average speed since start.
391    ///
392    /// Returns `None` if:
393    /// * No progress has been made.
394    /// * Total is zero.
395    /// * Process is finished.
396    /// * Elapsed time is effectively zero.
397    #[allow(clippy::cast_precision_loss)]
398    #[must_use]
399    pub fn eta(&self) -> Option<Duration> {
400        if self.position == 0 || self.total == 0 || self.finished {
401            return None;
402        }
403
404        let elapsed = self.elapsed?;
405        let secs = elapsed.as_secs_f64();
406
407        // Avoid division by zero or extremely small intervals
408        if secs <= 1e-6 {
409            return None;
410        }
411
412        let rate = self.position as f64 / secs;
413        if rate <= 0.0 {
414            return None;
415        }
416
417        let remaining_items = self.total.saturating_sub(self.position);
418        let remaining_secs = remaining_items as f64 / rate;
419
420        Some(Duration::from_secs_f64(remaining_secs))
421    }
422
423    /// Calculates the average throughput (items per second) over the entire lifetime.
424    #[allow(clippy::cast_precision_loss)]
425    #[must_use]
426    pub fn throughput(&self) -> f64 {
427        if let Some(elapsed) = self.elapsed {
428            let secs = elapsed.as_secs_f64();
429            if secs > 0.0 {
430                return self.position as f64 / secs;
431            }
432        }
433        0.0
434    }
435
436    /// Calculates the instantaneous throughput relative to a previous snapshot.
437    ///
438    /// This is useful for calculating "current speed" (e.g., in the last second).
439    #[allow(clippy::cast_precision_loss)]
440    #[must_use]
441    pub fn throughput_since(&self, prev: &Self) -> f64 {
442        let pos_diff = self.position.saturating_sub(prev.position) as f64;
443
444        let time_diff = match (self.elapsed, prev.elapsed) {
445            (Some(curr), Some(old)) => curr.as_secs_f64() - old.as_secs_f64(),
446            _ => 0.0,
447        };
448
449        if time_diff > 0.0 {
450            pos_diff / time_diff
451        } else {
452            0.0
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use std::thread;

    use super::Progress;

    /// Walks the basic state machine: new -> inc -> finish.
    #[test]
    #[allow(clippy::float_cmp)] // The percentages compared here are exactly representable.
    fn test_basic_lifecycle() {
        let job = Progress::new_pb("test_job", 100u64);

        assert_eq!(job.get_pos(), 0);
        assert!(!job.is_finished());
        assert_eq!(job.get_percent(), 0.0);

        job.inc(50u64);
        assert_eq!(job.get_pos(), 50);
        assert_eq!(job.get_percent(), 50.0);

        job.finish();
        assert!(job.is_finished());

        // `new` never starts the timer, so there is no elapsed time to report.
        assert!(job.get_elapsed().is_none());
    }

    /// Many threads incrementing clones of one handle concurrently must not lose updates.
    #[test]
    fn test_concurrency_atomics() {
        const THREADS: u64 = 10;
        const INCREMENTS_PER_THREAD: u64 = 100;

        let spinner = Progress::new_spinner("concurrent_job");

        // The scope joins every worker (and propagates panics) before returning.
        thread::scope(|scope| {
            for _ in 0..THREADS {
                let handle = spinner.clone();
                scope.spawn(move || {
                    (0..INCREMENTS_PER_THREAD).for_each(|_| handle.inc(1u64));
                });
            }
        });

        assert_eq!(
            spinner.get_pos(),
            THREADS * INCREMENTS_PER_THREAD,
            "Atomic updates should be lossless"
        );
    }

    /// Metadata ("cold" data) written through the handle must show up in snapshots.
    #[test]
    fn test_snapshot_metadata() {
        let job = Progress::new_pb("initial_name", 100u64);

        job.set_name("updated_name");
        job.set_item("file_a.txt");
        job.set_error(Some("disk_full"));

        let captured = job.snapshot();

        assert_eq!(captured.name, "updated_name");
        assert_eq!(captured.item, "file_a.txt");
        assert_eq!(captured.error, Some("disk_full".into()));
    }

    /// Derived metrics must degrade to safe values instead of producing NaN/Inf or panicking.
    #[test]
    #[allow(clippy::float_cmp)] // Comparing against the exact 0.0 fallback values.
    fn test_math_safety() {
        // No time elapsed and no progress: nothing meaningful to report.
        let fresh = Progress::new_pb("math_test", 100u64).snapshot();
        assert_eq!(fresh.throughput(), 0.0);
        assert!(fresh.eta().is_none());

        // Time cannot be mocked here without injection or sleeping, so only the
        // zero-total percentage guard is exercised.
        let empty = Progress::new_pb("zero_total", 0u64);
        assert_eq!(empty.get_percent(), 0.0);
    }
}