progress_token/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use futures::{Stream, ready};
4use pin_project_lite::pin_project;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use thiserror::Error;
9use tokio::sync::broadcast;
10use tokio_stream::wrappers::BroadcastStream;
11use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
12
13/// A guard that automatically marks a [`ProgressToken`] as complete when dropped
14#[must_use = "if unused, the progress token will be completed immediately"]
15pub struct CompleteGuard<'a, S: Clone + Send + 'static> {
16    token: &'a ProgressToken<S>,
17}
18
19impl<'a, S: Clone + Send + 'static> CompleteGuard<'a, S> {
20    /// Forgets the guard without completing the progress token
21    pub fn forget(self) {
22        std::mem::forget(self);
23    }
24}
25
26impl<'a, S: Clone + Send + 'static> Drop for CompleteGuard<'a, S> {
27    fn drop(&mut self) {
28        self.token.complete();
29    }
30}
31
/// Represents either a determinate progress value or indeterminate state.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Progress {
    /// Known progress as a fraction of completion, conventionally in
    /// `[0.0, 1.0]` (`ProgressToken::update_progress` clamps to that range).
    Determinate(f64),
    /// Progress cannot currently be measured.
    Indeterminate,
}

impl Progress {
    /// Returns the progress fraction, or `None` when the progress is
    /// indeterminate.
    pub fn as_f64(&self) -> Option<f64> {
        match *self {
            Progress::Determinate(value) => Some(value),
            Progress::Indeterminate => None,
        }
    }
}
48
49#[derive(Debug, Clone, Copy, Error)]
50pub enum ProgressError {
51    /// Too many progress updates have occurred since last polled, so some of
52    /// them have been dropped
53    #[error("progress updates lagged")]
54    Lagged,
55    /// This progress token has been cancelled, no more updates are coming
56    #[error("the operation has been cancelled")]
57    Cancelled,
58}
59
60/// Data for a progress update event
61#[derive(Debug, Clone)]
62#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
63pub struct ProgressUpdate<S> {
64    pub progress: Progress,
65    pub statuses: Vec<S>,
66    pub is_cancelled: bool,
67}
68
69impl<S> ProgressUpdate<S> {
70    pub fn status(&self) -> &S {
71        self.statuses.last().unwrap()
72    }
73}
74
75/// Inner data of a progress node
76struct ProgressNodeInner<S> {
77    // Tree structure
78    parent: Option<Arc<ProgressNode<S>>>,
79    children: Vec<(Arc<ProgressNode<S>>, f64)>, // Node and its weight
80
81    // Progress state
82    progress: Progress,
83    status: S,
84    is_completed: bool,
85
86    // Subscriber management
87    update_sender: broadcast::Sender<ProgressUpdate<S>>,
88}
89
90/// A node in the progress tree
91struct ProgressNode<S> {
92    inner: Mutex<ProgressNodeInner<S>>,
93    // change_notify: Notify,
94}
95
96impl<S: Clone + Send> ProgressNode<S> {
97    fn new(status: S) -> Self {
98        // create broadcast channel with reasonable buffer size
99        let (tx, _) = broadcast::channel(16);
100
101        Self {
102            inner: Mutex::new(ProgressNodeInner {
103                parent: None,
104                children: Vec::new(),
105                progress: Progress::Determinate(0.0),
106                status,
107                is_completed: false,
108                update_sender: tx,
109            }),
110            // change_notify: Notify::new(),
111        }
112    }
113
114    fn child(parent: &Arc<Self>, weight: f64, status: S) -> Arc<Self> {
115        let mut parent_inner = parent.inner.lock().unwrap();
116
117        // create broadcast channel with reasonable buffer size
118        let (tx, _) = broadcast::channel(16);
119
120        let child = Self {
121            inner: Mutex::new(ProgressNodeInner {
122                parent: Some(parent.clone()),
123                children: Vec::new(),
124                progress: Progress::Determinate(0.0),
125                status,
126                is_completed: false,
127                update_sender: tx,
128            }),
129            // change_notify: Notify::new(),
130        };
131
132        let child = Arc::new(child);
133
134        parent_inner.children.push((child.clone(), weight));
135
136        child
137    }
138
139    fn calculate_progress(node: &Arc<Self>) -> Progress {
140        let inner = node.inner.lock().unwrap();
141
142        // If this node itself is indeterminate, propagate that
143        if matches!(inner.progress, Progress::Indeterminate) {
144            return Progress::Indeterminate;
145        }
146
147        if inner.children.is_empty() {
148            return inner.progress;
149        }
150
151        // Check if any active child is indeterminate
152        let has_indeterminate = inner
153            .children
154            .iter()
155            .filter(|(child, _)| {
156                let child_inner = child.inner.lock().unwrap();
157                !child_inner.is_completed
158            })
159            .any(|(child, _)| matches!(Self::calculate_progress(child), Progress::Indeterminate));
160
161        if has_indeterminate {
162            return Progress::Indeterminate;
163        }
164
165        // Calculate weighted average of determinate children
166        let total: f64 = inner
167            .children
168            .iter()
169            .map(|(child, weight)| {
170                match Self::calculate_progress(child) {
171                    Progress::Determinate(p) => p * weight,
172                    Progress::Indeterminate => 0.0, // Shouldn't happen due to check above
173                }
174            })
175            .sum();
176
177        Progress::Determinate(total)
178    }
179
180    fn get_status_hierarchy(node: &Arc<Self>) -> Vec<S> {
181        let inner = node.inner.lock().unwrap();
182        let mut result = vec![inner.status.clone()];
183
184        // Find active child
185        if !inner.children.is_empty() {
186            let active_child = inner
187                .children
188                .iter()
189                .filter(|(child, _)| {
190                    let child_inner = child.inner.lock().unwrap();
191                    !child_inner.is_completed
192                })
193                .next();
194
195            if let Some((child, _)) = active_child {
196                let child_statuses = Self::get_status_hierarchy(child);
197                result.extend(child_statuses);
198            }
199        }
200
201        result
202    }
203
204    fn notify_subscribers(node: &Arc<Self>, is_cancelled: bool) {
205        // Create update while holding the lock
206        let update = ProgressUpdate {
207            progress: Self::calculate_progress(node),
208            statuses: Self::get_status_hierarchy(node),
209            is_cancelled,
210        };
211
212        // Send updates without holding the lock
213        {
214            let inner = node.inner.lock().unwrap();
215            // broadcast to all subscribers, ignore send errors (no subscribers/full)
216            let _ = inner.update_sender.send(update);
217        };
218
219        // Notify waiters
220        // node.change_notify.notify_waiters();
221
222        // Propagate to parent
223        let parent = {
224            let inner = node.inner.lock().unwrap();
225            inner.parent.clone()
226        };
227
228        if let Some(parent) = parent {
229            Self::notify_subscribers(&parent, false);
230        }
231    }
232}
233
234/// A token that tracks the progress of a task and can be organized hierarchically
235#[derive(Clone)]
236pub struct ProgressToken<S> {
237    node: Arc<ProgressNode<S>>,
238    cancel_token: CancellationToken,
239}
240
241impl<S: std::fmt::Debug> std::fmt::Debug for ProgressToken<S> {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("ProgressToken")
244            .field("is_cancelled", &self.cancel_token.is_cancelled())
245            .finish()
246    }
247}
248
249impl<S: Clone + Send + 'static> ProgressToken<S> {
250    /// Create a new root ProgressToken
251    pub fn new(status: impl Into<S>) -> Self {
252        let node = Arc::new(ProgressNode::new(status.into()));
253
254        Self {
255            node,
256            cancel_token: CancellationToken::new(),
257        }
258    }
259
260    /// Create a child token
261    pub fn child(&self, weight: f64, status: impl Into<S>) -> Self {
262        let node = ProgressNode::child(&self.node, weight, status.into());
263
264        Self {
265            node,
266            cancel_token: self.cancel_token.child_token(),
267        }
268    }
269
270    /// Update the progress of this token
271    pub fn update_progress(&self, progress: f64) {
272        if self.is_cancelled() {
273            return;
274        }
275
276        let is_completed = {
277            let inner = self.node.inner.lock().unwrap();
278            inner.is_completed
279        };
280
281        if is_completed {
282            return;
283        }
284
285        let mut inner = self.node.inner.lock().unwrap();
286        inner.progress = Progress::Determinate(progress.max(0.0).min(1.0));
287        drop(inner);
288
289        ProgressNode::notify_subscribers(&self.node, false);
290    }
291
292    /// Set the progress state to indeterminate
293    pub fn update_indeterminate(&self) {
294        if self.is_cancelled() {
295            return;
296        }
297
298        let mut inner = self.node.inner.lock().unwrap();
299        if inner.is_completed {
300            return;
301        }
302
303        inner.progress = Progress::Indeterminate;
304        drop(inner);
305
306        ProgressNode::notify_subscribers(&self.node, false);
307    }
308
309    /// Update the status message
310    pub fn update_status(&self, status: impl Into<S>) {
311        if self.is_cancelled() {
312            return;
313        }
314
315        let mut inner = self.node.inner.lock().unwrap();
316        if inner.is_completed {
317            return;
318        }
319
320        inner.status = status.into();
321        drop(inner);
322
323        ProgressNode::notify_subscribers(&self.node, false);
324    }
325
326    /// Update the progress and status message
327    pub fn update(&self, progress: Progress, status: impl Into<S>) {
328        if self.is_cancelled() {
329            return;
330        }
331
332        let mut inner = self.node.inner.lock().unwrap();
333        if inner.is_completed {
334            return;
335        }
336
337        inner.status = status.into();
338        inner.progress = progress;
339        drop(inner);
340
341        ProgressNode::notify_subscribers(&self.node, false);
342    }
343
344    /// Mark the task as complete
345    pub fn complete(&self) {
346        if self.is_cancelled() {
347            return;
348        }
349
350        let mut inner = self.node.inner.lock().unwrap();
351        if !inner.is_completed {
352            inner.is_completed = true;
353            inner.progress = Progress::Determinate(1.0);
354            drop(inner);
355
356            ProgressNode::notify_subscribers(&self.node, false);
357        }
358    }
359
360    /// Returns ProgressError::Cancelled if the token is cancelled, otherwise Ok.
361    pub fn check(&self) -> Result<(), ProgressError> {
362        if self.is_cancelled() {
363            Err(ProgressError::Cancelled)
364        } else {
365            Ok(())
366        }
367    }
368
369    pub fn is_cancelled(&self) -> bool {
370        self.cancel_token.is_cancelled()
371    }
372
373    /// Cancel this task and all its children
374    pub fn cancel(&self) {
375        if !self.cancel_token.is_cancelled() {
376            self.cancel_token.cancel();
377            ProgressNode::notify_subscribers(&self.node, true);
378        }
379    }
380
381    /// Get the current progress state asynchronously
382    pub fn state(&self) -> Progress {
383        ProgressNode::calculate_progress(&self.node)
384    }
385
386    /// Get all status messages in this hierarchy asynchronously
387    pub fn statuses(&self) -> Vec<S> {
388        ProgressNode::get_status_hierarchy(&self.node)
389    }
390
391    pub fn cancelled(&self) -> WaitForCancellationFuture {
392        self.cancel_token.cancelled()
393    }
394
395    pub async fn updated(&self) -> Result<ProgressUpdate<S>, ProgressError> {
396        let mut rx = {
397            let inner = self.node.inner.lock().unwrap();
398            inner.update_sender.subscribe()
399        };
400
401        tokio::select! {
402            _ = self.cancel_token.cancelled() => {
403                Err(ProgressError::Cancelled)
404            }
405            result = rx.recv() => {
406                match result {
407                    Ok(update) => Ok(update),
408                    Err(broadcast::error::RecvError::Closed) => Err(ProgressError::Cancelled),
409                    Err(broadcast::error::RecvError::Lagged(_)) => Err(ProgressError::Lagged),
410                }
411            }
412        }
413    }
414
415    /// Subscribe to progress updates from this token
416    pub fn subscribe(&self) -> ProgressStream<'_, S> {
417        let rx = {
418            let inner = self.node.inner.lock().unwrap();
419            inner.update_sender.subscribe()
420        };
421
422        ProgressStream {
423            token: self,
424            rx: BroadcastStream::new(rx),
425        }
426    }
427
428    /// Creates a guard that will automatically mark this token as complete when dropped
429    pub fn complete_guard(&self) -> CompleteGuard<'_, S> {
430        CompleteGuard { token: self }
431    }
432}
433
434pin_project! {
435    /// A Future that is resolved once the corresponding [`ProgressToken`]
436    /// is updated. Resolves to `None` if the progress token is cancelled.
437    #[must_use = "futures do nothing unless polled"]
438    pub struct WaitForUpdateFuture<'a, S> {
439        token: &'a ProgressToken<S>,
440        #[pin]
441        future: tokio::sync::futures::Notified<'a>,
442    }
443}
444
445impl<'a, S: Clone + Send + 'static> Future for WaitForUpdateFuture<'a, S> {
446    type Output = Option<ProgressUpdate<S>>;
447
448    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
449        let mut this = self.project();
450        if this.token.cancel_token.is_cancelled() {
451            return Poll::Ready(None);
452        }
453
454        ready!(this.future.as_mut().poll(cx));
455
456        Poll::Ready(Some(ProgressUpdate {
457            progress: this.token.state(),
458            statuses: this.token.statuses(),
459            is_cancelled: false,
460        }))
461    }
462}
463
464pin_project! {
465    /// A Stream that yields progress updates from a token
466    #[must_use = "streams do nothing unless polled"]
467    pub struct ProgressStream<'a, S> {
468        token: &'a ProgressToken<S>,
469        #[pin]
470        rx: BroadcastStream<ProgressUpdate<S>>,
471    }
472}
473
474impl<'a, S: Clone + Send + 'static> Stream for ProgressStream<'a, S> {
475    type Item = ProgressUpdate<S>;
476
477    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
478        self.project()
479            .rx
480            .poll_next(cx)
481            .map(|opt| opt.map(|res| res.unwrap()))
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;
    use tokio::time::sleep;

    // helper function to create a test hierarchy
    async fn create_test_hierarchy() -> (
        ProgressToken<String>,
        ProgressToken<String>,
        ProgressToken<String>,
    ) {
        let root = ProgressToken::new("root".to_string());
        let child1 = root.child(0.6, "child1".to_string());
        let child2 = root.child(0.4, "child2".to_string());
        (root, child1, child2)
    }

    #[tokio::test]
    async fn test_basic_progress_updates() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        token.update_progress(0.5);
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
        );

        token.update_progress(1.0);
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
        );

        // test progress clamping
        token.update_progress(1.5);
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
        );

        token.update_progress(-0.5);
        assert!(matches!(token.state(), Progress::Determinate(p) if p.abs() < f64::EPSILON));
    }

    #[tokio::test]
    async fn test_hierarchical_progress() {
        let (root, child1, child2) = create_test_hierarchy().await;

        // update children progress
        child1.update_progress(0.5);
        child2.update_progress(0.5);

        // root progress should be weighted average: 0.5 * 0.6 + 0.5 * 0.4 = 0.5
        assert!(matches!(root.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON));

        child1.update_progress(1.0);
        // root progress should now be: 1.0 * 0.6 + 0.5 * 0.4 = 0.8
        assert!(matches!(root.state(), Progress::Determinate(p) if (p - 0.8).abs() < f64::EPSILON));
    }

    #[tokio::test]
    async fn test_indeterminate_state() {
        let (root, child1, child2) = create_test_hierarchy().await;

        // set one child to indeterminate
        child1.update_indeterminate();
        child2.update_progress(0.5);

        // root should be indeterminate
        assert!(matches!(root.state(), Progress::Indeterminate));

        // set child back to determinate
        child1.update_progress(0.5);
        assert!(matches!(root.state(), Progress::Determinate(_)));
    }

    #[tokio::test]
    async fn test_status_updates() {
        let token: ProgressToken<String> = ProgressToken::new("initial status".to_string());
        let statuses = token.statuses();
        assert_eq!(statuses, vec!["initial status".to_string()]);

        token.update_status("updated status".to_string());
        let statuses = token.statuses();
        assert_eq!(statuses, vec!["updated status".to_string()]);
    }

    #[tokio::test]
    async fn test_status_hierarchy() {
        let (root, child1, _) = create_test_hierarchy().await;

        let statuses = root.statuses();
        assert_eq!(statuses, vec!["root".to_string(), "child1".to_string()]);

        child1.update_status("updated child1".to_string());
        let statuses = root.statuses();
        assert_eq!(
            statuses,
            vec!["root".to_string(), "updated child1".to_string()]
        );
    }

    #[tokio::test]
    async fn test_cancellation() {
        let (root, child1, child2) = create_test_hierarchy().await;

        // cancel root
        root.cancel();

        assert!(root.cancel_token.is_cancelled());
        assert!(child1.cancel_token.is_cancelled());
        assert!(child2.cancel_token.is_cancelled());

        // updates should not be processed after cancellation
        child1.update_progress(0.5);
        assert!(matches!(child1.state(), Progress::Determinate(p) if p.abs() < f64::EPSILON));
    }

    #[tokio::test]
    async fn test_complete_guard() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());

        {
            let _guard = token.complete_guard();
            token.update_progress(0.5);
            assert!(
                matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
            );
        } // guard is dropped here, token should be completed

        // token should be completed and at progress 1.0
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
        );

        // updates after completion should be ignored
        token.update_progress(0.5);
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
        );

        // test forget
        let token: ProgressToken<String> = ProgressToken::new("test2".to_string());
        {
            let guard = token.complete_guard();
            token.update_progress(0.5);
            guard.forget(); // prevent completion
        }

        // token should still be at 0.5 since guard was forgotten
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
        );
    }

    #[tokio::test]
    async fn test_subscription() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        let mut subscription = token.subscribe();

        // initial update
        let update = subscription.next().await.unwrap();
        assert_eq!(update.status(), &"test".to_string());
        assert!(matches!(update.progress, Progress::Determinate(p) if p.abs() < f64::EPSILON));

        // progress update
        token.update_progress(0.5);
        let update = subscription.next().await.unwrap();
        assert!(
            matches!(update.progress, Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
        );
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        let mut sub1 = token.subscribe();
        let mut sub2 = token.subscribe();

        // Skip initial updates
        sub1.next().await.unwrap();
        sub2.next().await.unwrap();

        // both subscribers should receive updates
        token.update_progress(0.5);

        let update1 = sub1.next().await.unwrap();
        let update2 = sub2.next().await.unwrap();

        assert!(
            matches!(update1.progress, Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON),
            "{update1:?}"
        );
        assert!(
            matches!(update2.progress, Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON),
            "{update2:?}"
        );
    }

    #[tokio::test]
    async fn test_concurrent_updates() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        let mut handles = vec![];

        // spawn multiple tasks updating the same token
        for i in 0..10 {
            let token = token.clone();
            handles.push(tokio::spawn(async move {
                sleep(Duration::from_millis(i * 10)).await;
                token.update_progress(i as f64 / 10.0);
            }));
        }

        // wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // final progress should be from the last update (0.9)
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 0.9).abs() < f64::EPSILON)
        );
    }

    #[tokio::test]
    async fn test_edge_cases() {
        // single node tree
        let token: ProgressToken<String> = ProgressToken::new("single".to_string());
        token.update_progress(0.5);
        assert!(
            matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
        );

        // deep hierarchy; keep a handle on the root so propagation can be checked
        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
        let mut current = root.clone();
        for i in 0..10 {
            current = current.child(1.0, format!("child{i}"));
        }

        // the active-status chain runs from the root down to the leaf
        assert_eq!(root.statuses().len(), 11);

        // update leaf node
        current.update_progress(1.0);
        assert!(
            matches!(current.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
        );
        // progress should propagate all the way to the root
        assert!(
            matches!(root.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
        );
    }
}