// progress_token/lib.rs

//! A hierarchical progress tracking library for monitoring and reporting the progress
//! of long-running tasks, with support for cancellation, status updates, and nested operations.
//!
//! # Overview
//!
//! This library provides a [`ProgressToken`] type that combines progress tracking, status updates,
//! and cancellation into a single unified interface. Unlike a basic [`CancellationToken`],
//! `ProgressToken` supports:
//!
//! - Hierarchical progress tracking with weighted child operations
//! - Status message updates that propagate through the hierarchy
//! - Both determinate (0.0-1.0) and indeterminate progress states
//! - Multiple subscribers for progress/status updates
//! - Automatic progress aggregation from child tasks
//!
//! # Examples
//!
//! Basic usage with progress updates. A subscription only receives updates sent after it
//! was created, so subscribe before making changes:
//!
//! ```rust
//! use futures::StreamExt;
//! use progress_token::ProgressToken;
//!
//! #[tokio::main]
//! async fn main() {
//!     // Create a root token; the status type (`String` here) is chosen by the caller
//!     let token = ProgressToken::<String>::new("Processing files");
//!
//!     // Subscribe to updates
//!     let mut updates = token.subscribe();
//!
//!     // Update progress as work is done; each call broadcasts one update
//!     token.progress(0.25);
//!     token.status("Processing file 1/4");
//!
//!     for _ in 0..2 {
//!         let update = updates.next().await.unwrap();
//!         println!("Progress: {:?}, Status: {}", update.progress, update.status());
//!     }
//! }
//! ```
//!
//! Using hierarchical progress with weighted child tasks:
//!
//! ```rust
//! use progress_token::ProgressToken;
//!
//! #[tokio::main]
//! async fn main() {
//!     let root = ProgressToken::<String>::new("Main task");
//!
//!     // Create child tasks with weights
//!     let process = ProgressToken::child(&root, 0.7, "Processing"); // 70% of total
//!     let cleanup = ProgressToken::child(&root, 0.3, "Cleanup");    // 30% of total
//!
//!     // Child progress is automatically weighted and propagated
//!     process.progress(0.5);  // Root progress becomes 0.35 (0.5 * 0.7)
//!     cleanup.progress(1.0);  // Root progress becomes 0.65 (0.35 + 1.0 * 0.3)
//!
//!     let total = root.state().as_f64().unwrap();
//!     assert!((total - 0.65).abs() < 1e-9);
//! }
//! ```
//!
//! # Status Updates
//!
//! The status feature tracks the operation currently being performed. Calling `status`
//! replaces a token's own message. [`ProgressToken::statuses`] lists one message per level,
//! starting at the token and then following the first unfinished child at each level. The
//! last entry is therefore the most specific (deepest) operation in progress:
//!
//! ```rust
//! use progress_token::ProgressToken;
//!
//! #[tokio::main]
//! async fn main() {
//!     let root = ProgressToken::<String>::new("Backup");
//!     let compress = ProgressToken::child(&root, 1.0, "Compressing files");
//!
//!     // Replace the child's status with more specific information
//!     compress.status("Compressing images/photo1.jpg");
//!
//!     // Get the status hierarchy: one entry per level
//!     let statuses = root.statuses();
//!     assert_eq!(statuses, vec![
//!         "Backup",
//!         "Compressing images/photo1.jpg",
//!     ]);
//! }
//! ```
//!
//! # Differences from CancellationToken
//!
//! While this library includes cancellation support, [`ProgressToken`] provides several
//! advantages over a basic [`CancellationToken`]:
//!
//! - **Rich Status Information**: Track not just whether an operation is cancelled, but its
//!   current progress and status.
//! - **Hierarchical Structure**: Create trees of operations where progress and cancellation
//!   flow naturally through parent-child relationships.
//! - **Progress Aggregation**: Automatically calculate overall progress from weighted child
//!   tasks instead of managing it manually.
//! - **Multiple Subscribers**: Allow multiple parts of your application to monitor progress
//!   through the subscription system.
//! - **Determinate/Indeterminate States**: Handle both measurable progress and operations
//!   where progress cannot be determined.
//!
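//! # Implementation Notes
//!
//! - Values passed to [`ProgressToken::progress`] are clamped to the range 0.0-1.0
//! - Child task weights should sum to 1.0 for accurate progress calculation. A token with
//!   children normally reports the weighted sum of its children's progress rather than its own.
//! - Status updates and progress changes are broadcast to the token's subscribers and
//!   re-broadcast by each of its ancestors
//! - Completed or cancelled tokens ignore further progress/status updates
//! - Each token's broadcast channel holds 16 updates. A subscriber that falls further behind
//!   lags, and the oldest updates are overwritten.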
110
111use futures::{Stream, ready};
112use pin_project_lite::pin_project;
113use std::pin::Pin;
114use std::sync::atomic::{AtomicBool, Ordering};
115use std::sync::{Arc, Mutex};
116use std::task::{Context, Poll};
117use tokio::sync::broadcast;
118use tokio_stream::wrappers::BroadcastStream;
119use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
120
/// Progress of a task: either a known fraction of the work done, or an
/// indeterminate state in which no fraction can be given.
///
/// Derives `PartialEq` so states can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Progress {
    /// Fraction of the work completed. Values set through
    /// [`ProgressToken::progress`] are clamped to `0.0..=1.0`.
    Determinate(f64),
    /// The amount of work done cannot currently be measured.
    Indeterminate,
}
127
128impl Progress {
129    pub fn as_f64(&self) -> Option<f64> {
130        match self {
131            Progress::Determinate(v) => Some(*v),
132            Progress::Indeterminate => None,
133        }
134    }
135}
136
/// Errors returned by [`ProgressToken::updated`].
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it composes with
/// `?` and boxed/dynamic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressError {
    /// Too many progress updates have occurred since last polled, so some of
    /// them have been dropped
    Lagged,
    /// This progress token has been cancelled, no more updates are coming
    Cancelled,
}

impl std::fmt::Display for ProgressError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ProgressError::Lagged => {
                "progress updates were dropped because the receiver lagged behind"
            }
            ProgressError::Cancelled => "progress token was cancelled",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ProgressError {}
145
146/// Data for a progress update event
147#[derive(Debug, Clone)]
148pub struct ProgressUpdate<S> {
149    pub progress: Progress,
150    pub statuses: Vec<S>,
151    pub is_cancelled: bool,
152}
153
154impl<S> ProgressUpdate<S> {
155    pub fn status(&self) -> &S {
156        self.statuses.last().unwrap()
157    }
158}
159
160/// Inner data of a progress node
161struct ProgressNodeInner<S> {
162    // Tree structure
163    parent: Option<Arc<ProgressNode<S>>>,
164    children: Vec<(Arc<ProgressNode<S>>, f64)>, // Node and its weight
165
166    // Progress state
167    progress: Progress,
168    status: S,
169    is_completed: bool,
170
171    // Subscriber management
172    update_sender: broadcast::Sender<ProgressUpdate<S>>,
173}
174
175/// A node in the progress tree
176struct ProgressNode<S> {
177    inner: Mutex<ProgressNodeInner<S>>,
178    // change_notify: Notify,
179}
180
181impl<S: Clone + Send> ProgressNode<S> {
182    fn new(status: S) -> Self {
183        // create broadcast channel with reasonable buffer size
184        let (tx, _) = broadcast::channel(16);
185
186        Self {
187            inner: Mutex::new(ProgressNodeInner {
188                parent: None,
189                children: Vec::new(),
190                progress: Progress::Determinate(0.0),
191                status,
192                is_completed: false,
193                update_sender: tx,
194            }),
195            // change_notify: Notify::new(),
196        }
197    }
198
199    fn child(parent: &Arc<Self>, weight: f64, status: S) -> Arc<Self> {
200        let mut parent_inner = parent.inner.lock().unwrap();
201
202        // create broadcast channel with reasonable buffer size
203        let (tx, _) = broadcast::channel(16);
204
205        let child = Self {
206            inner: Mutex::new(ProgressNodeInner {
207                parent: Some(parent.clone()),
208                children: Vec::new(),
209                progress: Progress::Determinate(0.0),
210                status,
211                is_completed: false,
212                update_sender: tx,
213            }),
214            // change_notify: Notify::new(),
215        };
216
217        let child = Arc::new(child);
218
219        parent_inner.children.push((child.clone(), weight));
220
221        child
222    }
223
224    fn calculate_progress(node: &Arc<Self>) -> Progress {
225        let inner = node.inner.lock().unwrap();
226
227        // If this node itself is indeterminate, propagate that
228        if matches!(inner.progress, Progress::Indeterminate) {
229            return Progress::Indeterminate;
230        }
231
232        if inner.children.is_empty() {
233            return inner.progress;
234        }
235
236        // Check if any active child is indeterminate
237        let has_indeterminate = inner
238            .children
239            .iter()
240            .filter(|(child, _)| {
241                let child_inner = child.inner.lock().unwrap();
242                !child_inner.is_completed
243            })
244            .any(|(child, _)| matches!(Self::calculate_progress(child), Progress::Indeterminate));
245
246        if has_indeterminate {
247            return Progress::Indeterminate;
248        }
249
250        // Calculate weighted average of determinate children
251        let total: f64 = inner
252            .children
253            .iter()
254            .map(|(child, weight)| {
255                match Self::calculate_progress(child) {
256                    Progress::Determinate(p) => p * weight,
257                    Progress::Indeterminate => 0.0, // Shouldn't happen due to check above
258                }
259            })
260            .sum();
261
262        Progress::Determinate(total)
263    }
264
265    fn get_status_hierarchy(node: &Arc<Self>) -> Vec<S> {
266        let inner = node.inner.lock().unwrap();
267        let mut result = vec![inner.status.clone()];
268
269        // Find active child
270        if !inner.children.is_empty() {
271            let active_child = inner
272                .children
273                .iter()
274                .filter(|(child, _)| {
275                    let child_inner = child.inner.lock().unwrap();
276                    !child_inner.is_completed
277                })
278                .next();
279
280            if let Some((child, _)) = active_child {
281                let child_statuses = Self::get_status_hierarchy(child);
282                result.extend(child_statuses);
283            }
284        }
285
286        result
287    }
288
289    fn notify_subscribers(node: &Arc<Self>, is_cancelled: bool) {
290        // Create update while holding the lock
291        let update = ProgressUpdate {
292            progress: Self::calculate_progress(node),
293            statuses: Self::get_status_hierarchy(node),
294            is_cancelled,
295        };
296
297        // Send updates without holding the lock
298        {
299            let inner = node.inner.lock().unwrap();
300            // broadcast to all subscribers, ignore send errors (no subscribers/full)
301            let _ = inner.update_sender.send(update);
302        };
303
304        // Notify waiters
305        // node.change_notify.notify_waiters();
306
307        // Propagate to parent
308        let parent = {
309            let inner = node.inner.lock().unwrap();
310            inner.parent.clone()
311        };
312
313        if let Some(parent) = parent {
314            Self::notify_subscribers(&parent, false);
315        }
316    }
317}
318
319/// A token that tracks the progress of a task and can be organized hierarchically
320#[derive(Clone)]
321pub struct ProgressToken<S> {
322    node: Arc<ProgressNode<S>>,
323    is_active: Arc<AtomicBool>,
324    cancel_token: CancellationToken,
325}
326
327impl<S: Clone + Send + 'static> ProgressToken<S> {
328    /// Create a new root ProgressToken
329    pub fn new(status: impl Into<S>) -> Arc<Self> {
330        let node = Arc::new(ProgressNode::new(status.into()));
331
332        Arc::new(Self {
333            node,
334            is_active: Arc::new(AtomicBool::new(true)),
335            cancel_token: CancellationToken::new(),
336        })
337    }
338
339    /// Create a child token
340    pub fn child(parent: &Arc<Self>, weight: f64, status: impl Into<S>) -> Arc<Self> {
341        let node = ProgressNode::child(&parent.node, weight, status.into());
342
343        Arc::new(Self {
344            node,
345            is_active: Arc::new(AtomicBool::new(true)),
346            cancel_token: parent.cancel_token.child_token(),
347        })
348    }
349
350    /// Update the progress of this token
351    pub fn progress(&self, progress: f64) {
352        if !self.is_active.load(Ordering::Relaxed) || self.cancel_token.is_cancelled() {
353            return;
354        }
355
356        // // Rate-limit updates to avoid too many notifications
357        // let now = Instant::now();
358        // let count = self.update_count.fetch_add(1, Ordering::Relaxed);
359
360        // if now.duration_since(self.last_update) > Duration::from_millis(100) || count % 10 == 0 {
361        let mut inner = self.node.inner.lock().unwrap();
362        inner.progress = Progress::Determinate(progress.max(0.0).min(1.0));
363        drop(inner);
364
365        ProgressNode::notify_subscribers(&self.node, false);
366        // }
367    }
368
369    /// Set the progress state to indeterminate
370    pub fn indeterminate(&self) {
371        if !self.is_active.load(Ordering::Relaxed) || self.cancel_token.is_cancelled() {
372            return;
373        }
374
375        let mut inner = self.node.inner.lock().unwrap();
376        inner.progress = Progress::Indeterminate;
377        drop(inner);
378
379        ProgressNode::notify_subscribers(&self.node, false);
380    }
381
382    /// Update the status message
383    pub fn status(&self, status: impl Into<S>) {
384        if !self.is_active.load(Ordering::Relaxed) || self.cancel_token.is_cancelled() {
385            return;
386        }
387
388        let mut inner = self.node.inner.lock().unwrap();
389        inner.status = status.into();
390        drop(inner);
391
392        ProgressNode::notify_subscribers(&self.node, false);
393    }
394
395    /// Mark the task as complete
396    pub fn complete(&self) {
397        if self.is_active.swap(false, Ordering::Relaxed) {
398            let mut inner = self.node.inner.lock().unwrap();
399            inner.is_completed = true;
400            inner.progress = Progress::Determinate(1.0);
401            drop(inner);
402
403            ProgressNode::notify_subscribers(&self.node, false);
404        }
405    }
406
407    /// Cancel this task and all its children
408    pub fn cancel(&self) {
409        if self.is_active.swap(false, Ordering::Relaxed) {
410            self.cancel_token.cancel();
411
412            ProgressNode::notify_subscribers(&self.node, true);
413        }
414    }
415
416    /// Check if this token has been cancelled
417    pub fn is_cancelled(&self) -> bool {
418        self.cancel_token.is_cancelled()
419    }
420
421    /// Get the current progress state asynchronously
422    pub fn state(&self) -> Progress {
423        ProgressNode::calculate_progress(&self.node)
424    }
425
426    /// Get all status messages in this hierarchy asynchronously
427    pub fn statuses(&self) -> Vec<S> {
428        ProgressNode::get_status_hierarchy(&self.node)
429    }
430
431    pub fn cancelled(&self) -> WaitForCancellationFuture {
432        self.cancel_token.cancelled()
433    }
434
435    pub async fn updated(&self) -> Result<ProgressUpdate<S>, ProgressError> {
436        let mut rx = {
437            let inner = self.node.inner.lock().unwrap();
438            inner.update_sender.subscribe()
439        };
440
441        tokio::select! {
442            _ = self.cancel_token.cancelled() => {
443                Err(ProgressError::Cancelled)
444            }
445            result = rx.recv() => {
446                match result {
447                    Ok(update) => Ok(update),
448                    Err(broadcast::error::RecvError::Closed) => Err(ProgressError::Cancelled),
449                    Err(broadcast::error::RecvError::Lagged(_)) => Err(ProgressError::Lagged),
450                }
451            }
452        }
453    }
454
455    /// Subscribe to progress updates from this token
456    pub fn subscribe(&self) -> ProgressStream<'_, S> {
457        let rx = {
458            let inner = self.node.inner.lock().unwrap();
459            inner.update_sender.subscribe()
460        };
461
462        ProgressStream {
463            token: self,
464            rx: BroadcastStream::new(rx),
465        }
466    }
467}
468
469pin_project! {
470    /// A Future that is resolved once the corresponding [`ProgressToken`]
471    /// is updated. Resolves to `None` if the progress token is cancelled.
472    #[must_use = "futures do nothing unless polled"]
473    pub struct WaitForUpdateFuture<'a, S> {
474        token: &'a ProgressToken<S>,
475        #[pin]
476        future: tokio::sync::futures::Notified<'a>,
477    }
478}
479
480impl<'a, S: Clone + Send + 'static> Future for WaitForUpdateFuture<'a, S> {
481    type Output = Option<ProgressUpdate<S>>;
482
483    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
484        let mut this = self.project();
485        if this.token.is_cancelled() {
486            return Poll::Ready(None);
487        }
488
489        ready!(this.future.as_mut().poll(cx));
490
491        Poll::Ready(Some(ProgressUpdate {
492            progress: this.token.state(),
493            statuses: this.token.statuses(),
494            is_cancelled: false,
495        }))
496    }
497}
498
499pin_project! {
500    /// A Stream that yields progress updates from a token
501    #[must_use = "streams do nothing unless polled"]
502    pub struct ProgressStream<'a, S> {
503        token: &'a ProgressToken<S>,
504        #[pin]
505        rx: BroadcastStream<ProgressUpdate<S>>,
506    }
507}
508
509impl<'a, S: Clone + Send + 'static> Stream for ProgressStream<'a, S> {
510    type Item = ProgressUpdate<S>;
511
512    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
513        self.project()
514            .rx
515            .poll_next(cx)
516            .map(|opt| opt.map(|res| res.unwrap()))
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Asserts that `progress` is determinate and within `f64::EPSILON` of `expected`.
    #[track_caller]
    fn assert_progress(progress: Progress, expected: f64) {
        assert!(
            matches!(progress, Progress::Determinate(p) if (p - expected).abs() < f64::EPSILON),
            "expected Determinate({expected}), got {progress:?}"
        );
    }

    /// Builds `root` with two children weighted 0.6 and 0.4.
    ///
    /// The status type must be named explicitly (`ProgressToken::<String>`):
    /// `impl Into<S>` alone cannot infer `S`.
    fn create_test_hierarchy() -> (
        Arc<ProgressToken<String>>,
        Arc<ProgressToken<String>>,
        Arc<ProgressToken<String>>,
    ) {
        let root = ProgressToken::<String>::new("root");
        let child1 = ProgressToken::child(&root, 0.6, "child1");
        let child2 = ProgressToken::child(&root, 0.4, "child2");
        (root, child1, child2)
    }

    #[tokio::test]
    async fn test_basic_progress_updates() {
        let token = ProgressToken::<String>::new("test");
        token.progress(0.5);
        assert_progress(token.state(), 0.5);

        token.progress(1.0);
        assert_progress(token.state(), 1.0);

        // test progress clamping
        token.progress(1.5);
        assert_progress(token.state(), 1.0);

        token.progress(-0.5);
        assert_progress(token.state(), 0.0);
    }

    #[tokio::test]
    async fn test_hierarchical_progress() {
        let (root, child1, child2) = create_test_hierarchy();

        // update children progress
        child1.progress(0.5);
        child2.progress(0.5);

        // root progress should be weighted average: 0.5 * 0.6 + 0.5 * 0.4 = 0.5
        assert_progress(root.state(), 0.5);

        child1.progress(1.0);
        // root progress should now be: 1.0 * 0.6 + 0.5 * 0.4 = 0.8
        assert_progress(root.state(), 0.8);
    }

    #[tokio::test]
    async fn test_indeterminate_state() {
        let (root, child1, child2) = create_test_hierarchy();

        // set one child to indeterminate
        child1.indeterminate();
        child2.progress(0.5);

        // root should be indeterminate
        assert!(matches!(root.state(), Progress::Indeterminate));

        // set child back to determinate
        child1.progress(0.5);
        assert!(matches!(root.state(), Progress::Determinate(_)));
    }

    #[tokio::test]
    async fn test_status_updates() {
        let token = ProgressToken::<String>::new("initial status");
        assert_eq!(token.statuses(), vec!["initial status".to_string()]);

        token.status("updated status");
        assert_eq!(token.statuses(), vec!["updated status".to_string()]);
    }

    #[tokio::test]
    async fn test_status_hierarchy() {
        let (root, child1, _) = create_test_hierarchy();

        let statuses = root.statuses();
        assert_eq!(statuses, vec!["root".to_string(), "child1".to_string()]);

        child1.status("updated child1");
        let statuses = root.statuses();
        assert_eq!(
            statuses,
            vec!["root".to_string(), "updated child1".to_string()]
        );
    }

    #[tokio::test]
    async fn test_completed_child_hands_over_status() {
        let (root, child1, _child2) = create_test_hierarchy();

        child1.complete();

        // the next unfinished child becomes the active branch
        assert_eq!(
            root.statuses(),
            vec!["root".to_string(), "child2".to_string()]
        );
        // 1.0 * 0.6 + 0.0 * 0.4
        assert_progress(root.state(), 0.6);
    }

    #[tokio::test]
    async fn test_cancellation() {
        let (root, child1, child2) = create_test_hierarchy();

        // cancel root
        root.cancel();

        assert!(root.is_cancelled());
        assert!(child1.is_cancelled());
        assert!(child2.is_cancelled());

        // updates should not be processed after cancellation
        child1.progress(0.5);
        assert_progress(child1.state(), 0.0);
    }

    #[tokio::test]
    async fn test_cancel_notifies_subscribers() {
        let token = ProgressToken::<String>::new("test");
        let mut subscription = token.subscribe();

        token.cancel();

        let update = subscription.next().await.unwrap();
        assert!(update.is_cancelled);

        // waiting for further updates on a cancelled token reports cancellation
        let err = token.updated().await.unwrap_err();
        assert_eq!(format!("{err:?}"), "Cancelled");
    }

    #[tokio::test]
    async fn test_completion() {
        let token = ProgressToken::<String>::new("test");
        token.complete();

        assert_progress(token.state(), 1.0);

        // updates after completion should not be processed
        token.progress(0.5);
        assert_progress(token.state(), 1.0);
    }

    #[tokio::test]
    async fn test_subscription() {
        let token = ProgressToken::<String>::new("test");
        let mut subscription = token.subscribe();

        // subscribing does not emit a snapshot: the first item is the next change
        token.progress(0.5);
        let update = subscription.next().await.unwrap();
        assert_eq!(update.status(), &"test".to_string());
        assert_progress(update.progress, 0.5);

        // status update
        token.status("working");
        let update = subscription.next().await.unwrap();
        assert_eq!(update.status(), &"working".to_string());
        assert_progress(update.progress, 0.5);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let token = ProgressToken::<String>::new("test");
        let mut sub1 = token.subscribe();
        let mut sub2 = token.subscribe();

        // both subscribers should receive updates
        token.progress(0.5);

        let update1 = sub1.next().await.unwrap();
        let update2 = sub2.next().await.unwrap();

        assert_progress(update1.progress, 0.5);
        assert_progress(update2.progress, 0.5);
    }

    #[tokio::test]
    async fn test_concurrent_updates() {
        let token = ProgressToken::<String>::new("test");
        let mut handles = vec![];

        // spawn multiple tasks updating the same token
        for i in 0..10 {
            let token = Arc::clone(&token);
            handles.push(tokio::spawn(async move {
                sleep(Duration::from_millis(i * 10)).await;
                token.progress(i as f64 / 10.0);
            }));
        }

        // wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // final progress should be from the last update (0.9)
        assert_progress(token.state(), 0.9);
    }

    #[tokio::test]
    async fn test_edge_cases() {
        // single node tree
        let token = ProgressToken::<String>::new("single");
        token.progress(0.5);
        assert_progress(token.state(), 0.5);

        // deep hierarchy; keep the root handle to check propagation
        let root = ProgressToken::<String>::new("root");
        let mut current = Arc::clone(&root);
        for i in 0..10 {
            current = ProgressToken::child(&current, 1.0, format!("child{i}"));
        }

        // update leaf node
        current.progress(1.0);
        assert_progress(current.state(), 1.0);

        // progress should propagate to root
        assert_progress(root.state(), 1.0);
        // root plus ten nested children, all unfinished
        assert_eq!(root.statuses().len(), 11);
    }
}