progress_token/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use futures::{Stream, ready};
4use pin_project_lite::pin_project;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use thiserror::Error;
9use tokio::sync::broadcast;
10use tokio_stream::wrappers::BroadcastStream;
11use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
12use tokio_util::sync::{
13    CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned,
14};
15
16/// A guard that automatically marks a [`ProgressToken`] as complete when dropped
17#[must_use = "if unused, the progress token will be completed immediately"]
18pub struct CompleteGuard<'a, S: Clone + Send + 'static> {
19    token: &'a ProgressToken<S>,
20}
21
22impl<'a, S: Clone + Send + 'static> CompleteGuard<'a, S> {
23    /// Forgets the guard without completing the progress token
24    pub fn forget(self) {
25        std::mem::forget(self);
26    }
27}
28
29impl<'a, S: Clone + Send + 'static> Drop for CompleteGuard<'a, S> {
30    fn drop(&mut self) {
31        self.token.complete();
32    }
33}
34
/// Progress of a task: either a known completion fraction or an unknown amount.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Progress {
    /// Known completion fraction. Values set through the token's update
    /// methods are clamped into `0.0..=1.0`; a parent's value is the weighted
    /// sum of its children's values.
    Determinate(f64),
    /// The amount of progress cannot currently be quantified.
    Indeterminate,
}
42
43impl Progress {
44    pub fn as_f64(&self) -> Option<f64> {
45        match self {
46            Progress::Determinate(v) => Some(*v),
47            Progress::Indeterminate => None,
48        }
49    }
50}
51
52#[derive(Debug, Clone, Copy, Error)]
53pub enum ProgressError {
54    /// Too many progress updates have occurred since last polled, so some of
55    /// them have been dropped
56    #[error("progress updates lagged")]
57    Lagged,
58    /// This progress token has been cancelled, no more updates are coming
59    #[error("the operation has been cancelled")]
60    Cancelled,
61}
62
/// Data for a progress update event
///
/// A snapshot of a node's aggregated state, sent to that node's subscribers
/// whenever the node or one of its descendants changes.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProgressUpdate<S> {
    /// Aggregated progress of the node when the update was produced.
    pub progress: Progress,
    /// Statuses along the active path: the node's own status first, then the
    /// status of its first not-yet-completed child, and so on down the tree.
    /// Updates produced by this crate always contain at least one entry.
    pub statuses: Vec<S>,
    /// `true` only for the update emitted on the node that `cancel()` was
    /// called on; updates propagated to its ancestors carry `false`.
    pub is_cancelled: bool,
}
71
72impl<S> ProgressUpdate<S> {
73    pub fn status(&self) -> &S {
74        self.statuses.last().unwrap()
75    }
76}
77
78/// Inner data of a progress node
79struct ProgressNodeInner<S> {
80    // Tree structure
81    parent: Option<Arc<ProgressNode<S>>>,
82    children: Vec<(Arc<ProgressNode<S>>, f64)>, // Node and its weight
83
84    // Progress state
85    progress: Progress,
86    status: S,
87    is_completed: bool,
88
89    // Subscriber management
90    update_sender: broadcast::Sender<ProgressUpdate<S>>,
91}
92
93/// A node in the progress tree
94struct ProgressNode<S> {
95    inner: Mutex<ProgressNodeInner<S>>,
96    // change_notify: Notify,
97}
98
99impl<S: Clone + Send> ProgressNode<S> {
100    fn new(status: S) -> Self {
101        // create broadcast channel with reasonable buffer size
102        let (tx, _) = broadcast::channel(16);
103
104        Self {
105            inner: Mutex::new(ProgressNodeInner {
106                parent: None,
107                children: Vec::new(),
108                progress: Progress::Determinate(0.0),
109                status,
110                is_completed: false,
111                update_sender: tx,
112            }),
113            // change_notify: Notify::new(),
114        }
115    }
116
117    fn child(parent: &Arc<Self>, weight: f64, status: S) -> Arc<Self> {
118        let mut parent_inner = parent.inner.lock().unwrap();
119
120        // create broadcast channel with reasonable buffer size
121        let (tx, _) = broadcast::channel(16);
122
123        let child = Self {
124            inner: Mutex::new(ProgressNodeInner {
125                parent: Some(parent.clone()),
126                children: Vec::new(),
127                progress: Progress::Determinate(0.0),
128                status,
129                is_completed: false,
130                update_sender: tx,
131            }),
132            // change_notify: Notify::new(),
133        };
134
135        let child = Arc::new(child);
136
137        parent_inner.children.push((child.clone(), weight));
138
139        child
140    }
141
142    fn calculate_progress(node: &Arc<Self>) -> Progress {
143        let inner = node.inner.lock().unwrap();
144
145        // If this node itself is indeterminate, propagate that
146        if matches!(inner.progress, Progress::Indeterminate) {
147            return Progress::Indeterminate;
148        }
149
150        if inner.children.is_empty() {
151            return inner.progress;
152        }
153
154        // Check if any active child is indeterminate
155        let has_indeterminate = inner
156            .children
157            .iter()
158            .filter(|(child, _)| {
159                let child_inner = child.inner.lock().unwrap();
160                !child_inner.is_completed
161            })
162            .any(|(child, _)| matches!(Self::calculate_progress(child), Progress::Indeterminate));
163
164        if has_indeterminate {
165            return Progress::Indeterminate;
166        }
167
168        // Calculate weighted average of determinate children
169        let total: f64 = inner
170            .children
171            .iter()
172            .map(|(child, weight)| {
173                match Self::calculate_progress(child) {
174                    Progress::Determinate(p) => p * weight,
175                    Progress::Indeterminate => 0.0, // Shouldn't happen due to check above
176                }
177            })
178            .sum();
179
180        Progress::Determinate(total)
181    }
182
183    fn get_status_hierarchy(node: &Arc<Self>) -> Vec<S> {
184        let inner = node.inner.lock().unwrap();
185        let mut result = vec![inner.status.clone()];
186
187        // Find active child
188        if !inner.children.is_empty() {
189            let active_child = inner
190                .children
191                .iter()
192                .filter(|(child, _)| {
193                    let child_inner = child.inner.lock().unwrap();
194                    !child_inner.is_completed
195                })
196                .next();
197
198            if let Some((child, _)) = active_child {
199                let child_statuses = Self::get_status_hierarchy(child);
200                result.extend(child_statuses);
201            }
202        }
203
204        result
205    }
206
207    fn notify_subscribers(node: &Arc<Self>, is_cancelled: bool) {
208        // Create update while holding the lock
209        let update = ProgressUpdate {
210            progress: Self::calculate_progress(node),
211            statuses: Self::get_status_hierarchy(node),
212            is_cancelled,
213        };
214
215        // Send updates without holding the lock
216        {
217            let inner = node.inner.lock().unwrap();
218            // broadcast to all subscribers, ignore send errors (no subscribers/full)
219            let _ = inner.update_sender.send(update);
220        };
221
222        // Notify waiters
223        // node.change_notify.notify_waiters();
224
225        // Propagate to parent
226        let parent = {
227            let inner = node.inner.lock().unwrap();
228            inner.parent.clone()
229        };
230
231        if let Some(parent) = parent {
232            Self::notify_subscribers(&parent, false);
233        }
234    }
235}
236
/// A token that tracks the progress of a task and can be organized hierarchically
///
/// Clones share the same progress node (held in an `Arc`) and a clone of the
/// same cancellation token. Progress updates and cancellation made through any
/// clone are therefore visible through all of them.
#[derive(Clone)]
pub struct ProgressToken<S> {
    // Shared node in the progress tree holding this task's state.
    node: Arc<ProgressNode<S>>,
    // Cancellation for this task; `child()` derives child tokens from it.
    cancel_token: CancellationToken,
}
243
244impl<S: Default + Clone + Send + 'static> Default for ProgressToken<S> {
245    fn default() -> Self {
246        Self::new(S::default())
247    }
248}
249
250impl<S: std::fmt::Debug> std::fmt::Debug for ProgressToken<S> {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        f.debug_struct("ProgressToken")
253            .field("is_cancelled", &self.cancel_token.is_cancelled())
254            .finish()
255    }
256}
257
258impl<S: Clone + Send + 'static> ProgressToken<S> {
259    /// Create a new root ProgressToken
260    pub fn new(status: impl Into<S>) -> Self {
261        let node = Arc::new(ProgressNode::new(status.into()));
262
263        Self {
264            node,
265            cancel_token: CancellationToken::new(),
266        }
267    }
268
269    /// Create a child token
270    pub fn child(&self, weight: f64, status: impl Into<S>) -> Self {
271        let node = ProgressNode::child(&self.node, weight, status.into());
272
273        Self {
274            node,
275            cancel_token: self.cancel_token.child_token(),
276        }
277    }
278
279    /// Update the progress of this token
280    pub fn update_progress(&self, progress: f64) {
281        if self.is_cancelled() {
282            return;
283        }
284
285        let is_completed = {
286            let inner = self.node.inner.lock().unwrap();
287            inner.is_completed
288        };
289
290        if is_completed {
291            return;
292        }
293
294        let mut inner = self.node.inner.lock().unwrap();
295        inner.progress = Progress::Determinate(progress.max(0.0).min(1.0));
296        drop(inner);
297
298        ProgressNode::notify_subscribers(&self.node, false);
299    }
300
301    /// Set the progress state to indeterminate
302    pub fn update_indeterminate(&self) {
303        if self.is_cancelled() {
304            return;
305        }
306
307        let mut inner = self.node.inner.lock().unwrap();
308        if inner.is_completed {
309            return;
310        }
311
312        inner.progress = Progress::Indeterminate;
313        drop(inner);
314
315        ProgressNode::notify_subscribers(&self.node, false);
316    }
317
318    /// Update the status message
319    pub fn update_status(&self, status: impl Into<S>) {
320        if self.is_cancelled() {
321            return;
322        }
323
324        let mut inner = self.node.inner.lock().unwrap();
325        if inner.is_completed {
326            return;
327        }
328
329        inner.status = status.into();
330        drop(inner);
331
332        ProgressNode::notify_subscribers(&self.node, false);
333    }
334
335    /// Update the progress and status message
336    pub fn update(&self, progress: Progress, status: impl Into<S>) {
337        if self.is_cancelled() {
338            return;
339        }
340
341        let mut inner = self.node.inner.lock().unwrap();
342        if inner.is_completed {
343            return;
344        }
345
346        inner.status = status.into();
347        inner.progress = progress;
348        drop(inner);
349
350        ProgressNode::notify_subscribers(&self.node, false);
351    }
352
353    /// Mark the task as complete
354    pub fn complete(&self) {
355        if self.is_cancelled() {
356            return;
357        }
358
359        let mut inner = self.node.inner.lock().unwrap();
360        if !inner.is_completed {
361            inner.is_completed = true;
362            inner.progress = Progress::Determinate(1.0);
363            drop(inner);
364
365            ProgressNode::notify_subscribers(&self.node, false);
366        }
367    }
368
369    /// Returns ProgressError::Cancelled if the token is cancelled, otherwise Ok.
370    pub fn check(&self) -> Result<(), ProgressError> {
371        if self.is_cancelled() {
372            Err(ProgressError::Cancelled)
373        } else {
374            Ok(())
375        }
376    }
377
378    pub fn is_cancelled(&self) -> bool {
379        self.cancel_token.is_cancelled()
380    }
381
382    /// Cancel this task and all its children
383    pub fn cancel(&self) {
384        if !self.cancel_token.is_cancelled() {
385            self.cancel_token.cancel();
386            ProgressNode::notify_subscribers(&self.node, true);
387        }
388    }
389
390    /// Get the current progress state asynchronously
391    pub fn state(&self) -> Progress {
392        ProgressNode::calculate_progress(&self.node)
393    }
394
395    /// Get all status messages in this hierarchy asynchronously
396    pub fn statuses(&self) -> Vec<S> {
397        ProgressNode::get_status_hierarchy(&self.node)
398    }
399
400    pub fn cancelled(&self) -> WaitForCancellationFuture {
401        self.cancel_token.cancelled()
402    }
403
404    pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
405        self.cancel_token.cancelled_owned()
406    }
407
408    pub async fn updated(&self) -> Result<ProgressUpdate<S>, ProgressError> {
409        let mut rx = {
410            let inner = self.node.inner.lock().unwrap();
411            inner.update_sender.subscribe()
412        };
413
414        tokio::select! {
415            _ = self.cancel_token.cancelled() => {
416                Err(ProgressError::Cancelled)
417            }
418            result = rx.recv() => {
419                match result {
420                    Ok(update) => Ok(update),
421                    Err(broadcast::error::RecvError::Closed) => Err(ProgressError::Cancelled),
422                    Err(broadcast::error::RecvError::Lagged(_)) => Err(ProgressError::Lagged),
423                }
424            }
425        }
426    }
427
428    /// Subscribe to progress updates from this token
429    pub fn subscribe(&self) -> ProgressStream<'_, S> {
430        let rx = {
431            let inner = self.node.inner.lock().unwrap();
432            inner.update_sender.subscribe()
433        };
434
435        ProgressStream {
436            token: self,
437            rx: BroadcastStream::new(rx),
438        }
439    }
440
441    /// Creates a guard that will automatically mark this token as complete when dropped
442    pub fn complete_guard(&self) -> CompleteGuard<'_, S> {
443        CompleteGuard { token: self }
444    }
445}
446
pin_project! {
    /// A Future that is resolved once the corresponding [`ProgressToken`]
    /// is updated. Resolves to `None` if the progress token is cancelled.
    ///
    /// On resolution it yields a fresh snapshot built from the token's
    /// `state()` and `statuses()`, with `is_cancelled: false`.
    // NOTE(review): nothing in the visible source constructs this future, and
    // `ProgressNode` has no `Notify` from which a `Notified` could be obtained.
    // Confirm whether this type is still reachable or needed.
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForUpdateFuture<'a, S> {
        // Token that is snapshotted once the notification fires.
        token: &'a ProgressToken<S>,
        // Pending notification; pinned because `Notified` must not move once polled.
        #[pin]
        future: tokio::sync::futures::Notified<'a>,
    }
}
457
458impl<'a, S: Clone + Send + 'static> Future for WaitForUpdateFuture<'a, S> {
459    type Output = Option<ProgressUpdate<S>>;
460
461    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
462        let mut this = self.project();
463        if this.token.cancel_token.is_cancelled() {
464            return Poll::Ready(None);
465        }
466
467        ready!(this.future.as_mut().poll(cx));
468
469        Poll::Ready(Some(ProgressUpdate {
470            progress: this.token.state(),
471            statuses: this.token.statuses(),
472            is_cancelled: false,
473        }))
474    }
475}
476
pin_project! {
    /// A Stream that yields progress updates from a token
    ///
    /// Created by `ProgressToken::subscribe`. It yields
    /// `Err(ProgressError::Lagged)` when the receiver falls behind the node's
    /// broadcast buffer, and ends only when the broadcast channel closes.
    #[must_use = "streams do nothing unless polled"]
    pub struct ProgressStream<'a, S> {
        // Never read; borrowing the token ties the stream's lifetime to it.
        token: &'a ProgressToken<S>,
        // Receiver created from the node's broadcast sender at subscription time.
        #[pin]
        rx: BroadcastStream<ProgressUpdate<S>>,
    }
}
486
487impl<'a, S: Clone + Send + 'static> Stream for ProgressStream<'a, S> {
488    type Item = Result<ProgressUpdate<S>, ProgressError>;
489
490    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
491        self.project().rx.poll_next(cx).map(|opt| {
492            opt.and_then(|res| match res {
493                Ok(update) => Some(Ok(update)),
494                Err(BroadcastStreamRecvError::Lagged(_)) => Some(Err(ProgressError::Lagged)),
495            })
496        })
497    }
498}
499
500#[cfg(test)]
501mod tests {
502    use super::*;
503    use futures::StreamExt;
504    use std::time::Duration;
505    use tokio::time::sleep;
506
507    // helper function to create a test hierarchy
508    async fn create_test_hierarchy() -> (
509        ProgressToken<String>,
510        ProgressToken<String>,
511        ProgressToken<String>,
512    ) {
513        let root = ProgressToken::new("root".to_string());
514        let child1 = root.child(0.6, "child1".to_string());
515        let child2 = root.child(0.4, "child2".to_string());
516        (root, child1, child2)
517    }
518
519    #[tokio::test]
520    async fn test_basic_progress_updates() {
521        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
522        token.update_progress(0.5);
523        assert!(
524            matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
525        );
526
527        token.update_progress(1.0);
528        assert!(
529            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
530        );
531
532        // test progress clamping
533        token.update_progress(1.5);
534        assert!(
535            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
536        );
537
538        token.update_progress(-0.5);
539        assert!(matches!(token.state(), Progress::Determinate(p) if p.abs() < f64::EPSILON));
540    }
541
542    #[tokio::test]
543    async fn test_hierarchical_progress() {
544        let (root, child1, child2) = create_test_hierarchy().await;
545
546        // update children progress
547        child1.update_progress(0.5);
548        child2.update_progress(0.5);
549
550        // root progress should be weighted average: 0.5 * 0.6 + 0.5 * 0.4 = 0.5
551        assert!(matches!(root.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON));
552
553        child1.update_progress(1.0);
554        // root progress should now be: 1.0 * 0.6 + 0.5 * 0.4 = 0.8
555        assert!(matches!(root.state(), Progress::Determinate(p) if (p - 0.8).abs() < f64::EPSILON));
556    }
557
558    #[tokio::test]
559    async fn test_indeterminate_state() {
560        let (root, child1, child2) = create_test_hierarchy().await;
561
562        // set one child to indeterminate
563        child1.update_indeterminate();
564        child2.update_progress(0.5);
565
566        // root should be indeterminate
567        assert!(matches!(root.state(), Progress::Indeterminate));
568
569        // set child back to determinate
570        child1.update_progress(0.5);
571        assert!(matches!(root.state(), Progress::Determinate(_)));
572    }
573
574    #[tokio::test]
575    async fn test_status_updates() {
576        let token: ProgressToken<String> = ProgressToken::new("initial status".to_string());
577        let statuses = token.statuses();
578        assert_eq!(statuses, vec!["initial status".to_string()]);
579
580        token.update_status("updated status".to_string());
581        let statuses = token.statuses();
582        assert_eq!(statuses, vec!["updated status".to_string()]);
583    }
584
585    #[tokio::test]
586    async fn test_status_hierarchy() {
587        let (root, child1, _) = create_test_hierarchy().await;
588
589        let statuses = root.statuses();
590        assert_eq!(statuses, vec!["root".to_string(), "child1".to_string()]);
591
592        child1.update_status("updated child1".to_string());
593        let statuses = root.statuses();
594        assert_eq!(
595            statuses,
596            vec!["root".to_string(), "updated child1".to_string()]
597        );
598    }
599
600    #[tokio::test]
601    async fn test_cancellation() {
602        let (root, child1, child2) = create_test_hierarchy().await;
603
604        // cancel root
605        root.cancel();
606
607        assert!(root.cancel_token.is_cancelled());
608        assert!(child1.cancel_token.is_cancelled());
609        assert!(child2.cancel_token.is_cancelled());
610
611        // updates should not be processed after cancellation
612        child1.update_progress(0.5);
613        assert!(matches!(child1.state(), Progress::Determinate(p) if p.abs() < f64::EPSILON));
614    }
615
616    #[tokio::test]
617    async fn test_complete_guard() {
618        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
619
620        {
621            let _guard = token.complete_guard();
622            token.update_progress(0.5);
623            assert!(
624                matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
625            );
626        } // guard is dropped here, token should be completed
627
628        // token should be completed and at progress 1.0
629        assert!(
630            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
631        );
632
633        // updates after completion should be ignored
634        token.update_progress(0.5);
635        assert!(
636            matches!(token.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
637        );
638
639        // test forget
640        let token: ProgressToken<String> = ProgressToken::new("test2".to_string());
641        {
642            let guard = token.complete_guard();
643            token.update_progress(0.5);
644            guard.forget(); // prevent completion
645        }
646
647        // token should still be at 0.5 since guard was forgotten
648        assert!(
649            matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
650        );
651    }
652
653    #[tokio::test]
654    async fn test_subscription() {
655        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
656        let mut subscription = token.subscribe();
657
658        // // initial update
659        // let update = subscription.next().await.unwrap();
660        // assert_eq!(update.status(), &"test".to_string());
661        // assert!(matches!(update.progress, Progress::Determinate(p) if p.abs() < f64::EPSILON));
662
663        // progress update
664        token.update_progress(0.5);
665        let update = subscription.next().await.unwrap().unwrap();
666        assert!(
667            matches!(update.progress, Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
668        );
669    }
670
671    #[tokio::test]
672    async fn test_multiple_subscribers() {
673        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
674        let mut sub1 = token.subscribe();
675        let mut sub2 = token.subscribe();
676
677        // both subscribers should receive updates
678        token.update_progress(0.5);
679
680        let update1 = sub1.next().await.unwrap().unwrap();
681        let update2 = sub2.next().await.unwrap().unwrap();
682
683        assert!(
684            matches!(update1.progress, Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON),
685            "{update1:?}"
686        );
687        assert!(
688            matches!(update2.progress, Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON),
689            "{update2:?}"
690        );
691
692        // test that both subscribers receive subsequent updates
693        token.update_progress(0.75);
694
695        let update1 = sub1.next().await.unwrap().unwrap();
696        let update2 = sub2.next().await.unwrap().unwrap();
697
698        assert!(
699            matches!(update1.progress, Progress::Determinate(p) if (p - 0.75).abs() < f64::EPSILON),
700            "{update1:?}"
701        );
702        assert!(
703            matches!(update2.progress, Progress::Determinate(p) if (p - 0.75).abs() < f64::EPSILON),
704            "{update2:?}"
705        );
706    }
707
708    #[tokio::test]
709    async fn test_concurrent_updates() {
710        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
711        let mut handles = vec![];
712
713        // spawn multiple tasks updating the same token
714        for i in 0..10 {
715            let token = token.clone();
716            handles.push(tokio::spawn(async move {
717                sleep(Duration::from_millis(i * 10)).await;
718                token.update_progress(i as f64 / 10.0);
719            }));
720        }
721
722        // wait for all tasks to complete
723        for handle in handles {
724            handle.await.unwrap();
725        }
726
727        // final progress should be from the last update (0.9)
728        assert!(
729            matches!(token.state(), Progress::Determinate(p) if (p - 0.9).abs() < f64::EPSILON)
730        );
731    }
732
733    #[tokio::test]
734    async fn test_edge_cases() {
735        // single node tree
736        let token: ProgressToken<String> = ProgressToken::new("single".to_string());
737        token.update_progress(0.5);
738        assert!(
739            matches!(token.state(), Progress::Determinate(p) if (p - 0.5).abs() < f64::EPSILON)
740        );
741
742        // deep hierarchy
743        let mut current: ProgressToken<String> = ProgressToken::new("root".to_string());
744        for i in 0..10 {
745            current = current.child(1.0, format!("child{}", i));
746        }
747
748        // update leaf node
749        current.update_progress(1.0);
750        // progress should propagate to root
751        assert!(
752            matches!(current.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON)
753        );
754    }
755
756    #[tokio::test]
757    async fn test_three_level_hierarchy_progress() {
758        // create a three-level hierarchy with weighted progress
759        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
760
761        let child1 = root.child(0.7, "child1".to_string());
762        let child2 = root.child(0.3, "child2".to_string());
763
764        let grandchild1_1 = child1.child(0.6, "grandchild1_1".to_string());
765        let grandchild1_2 = child1.child(0.4, "grandchild1_2".to_string());
766        let grandchild2_1 = child2.child(1.0, "grandchild2_1".to_string());
767
768        // update progress of grandchildren
769        grandchild1_1.update_progress(0.5); // contributes: 0.5 * 0.6 * 0.7 = 0.21 to root
770        grandchild1_2.update_progress(1.0); // contributes: 1.0 * 0.4 * 0.7 = 0.28 to root
771        grandchild2_1.update_progress(0.6); // contributes: 0.6 * 1.0 * 0.3 = 0.18 to root
772
773        // child1's progress should be: (0.5 * 0.6) + (1.0 * 0.4) = 0.7
774        assert!(
775            matches!(child1.state(), Progress::Determinate(p) if (p - 0.7).abs() < f64::EPSILON),
776            "child1 progress incorrect"
777        );
778
779        // child2's progress should be: 0.6 * 1.0 = 0.6
780        assert!(
781            matches!(child2.state(), Progress::Determinate(p) if (p - 0.6).abs() < f64::EPSILON),
782            "child2 progress incorrect"
783        );
784
785        // root's total progress should be: (0.7 * 0.7) + (0.6 * 0.3) = 0.67
786        assert!(
787            matches!(root.state(), Progress::Determinate(p) if (p - 0.67).abs() < f64::EPSILON),
788            "root progress incorrect"
789        );
790    }
791
792    #[tokio::test]
793    async fn test_completion_hierarchy() {
794        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
795        let child1 = root.child(0.6, "child1".to_string());
796        let child2 = root.child(0.4, "child2".to_string());
797        let grandchild1 = child1.child(1.0, "grandchild1".to_string());
798
799        // update and complete grandchild
800        grandchild1.update_progress(0.5);
801        grandchild1.complete();
802
803        // grandchild should be at 100%
804        assert!(
805            matches!(grandchild1.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON),
806            "completed grandchild should be at 100%"
807        );
808
809        // child1's progress should reflect completed grandchild (100%)
810        assert!(
811            matches!(child1.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON),
812            "child1 progress should reflect completed grandchild"
813        );
814
815        // update other child
816        child2.update_progress(0.5);
817
818        // root's progress should reflect one completed child and one at 50%
819        // (1.0 * 0.6) + (0.5 * 0.4) = 0.8
820        assert!(
821            matches!(root.state(), Progress::Determinate(p) if (p - 0.8).abs() < f64::EPSILON),
822            "root progress incorrect after child completion"
823        );
824
825        // completing a child should not auto-complete its parent
826        child2.complete();
827        assert!(
828            matches!(root.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON),
829            "root progress should be 100% when all children complete"
830        );
831
832        // verify root is not marked as completed
833        let mut root_inner = root.node.inner.lock().unwrap();
834        assert!(
835            !root_inner.is_completed,
836            "root should not be auto-completed when children complete"
837        );
838    }
839
840    #[tokio::test]
841    async fn test_mixed_completion_states() {
842        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
843        let child1 = root.child(0.5, "child1".to_string());
844        let child2 = root.child(0.5, "child2".to_string());
845
846        let grandchild1_1 = child1.child(0.7, "grandchild1_1".to_string());
847        let grandchild1_2 = child1.child(0.3, "grandchild1_2".to_string());
848
849        // complete one grandchild but leave other incomplete
850        grandchild1_1.complete();
851        grandchild1_2.update_progress(0.5);
852
853        // child1's progress should be: (1.0 * 0.7) + (0.5 * 0.3) = 0.85
854        assert!(
855            matches!(child1.state(), Progress::Determinate(p) if (p - 0.85).abs() < f64::EPSILON),
856            "child1 progress incorrect with mixed completion"
857        );
858
859        // update other child
860        child2.update_progress(0.4);
861
862        // root's progress should be: (0.85 * 0.5) + (0.4 * 0.5) = 0.625
863        assert!(
864            matches!(root.state(), Progress::Determinate(p) if (p - 0.625).abs() < f64::EPSILON),
865            "root progress incorrect with mixed completion states"
866        );
867
868        // complete remaining nodes
869        grandchild1_2.complete();
870        child2.complete();
871
872        // verify final state
873        assert!(
874            matches!(root.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON),
875            "root progress should be 100% when all descendants complete"
876        );
877        assert!(
878            matches!(child1.state(), Progress::Determinate(p) if (p - 1.0).abs() < f64::EPSILON),
879            "child1 progress should be 100% when all grandchildren complete"
880        );
881    }
882
883    #[tokio::test]
884    async fn test_status_propagation() {
885        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
886        let child1 = root.child(0.6, "child1".to_string());
887        let child2 = root.child(0.4, "child2".to_string());
888        let grandchild1 = child1.child(1.0, "grandchild1".to_string());
889
890        // initial status hierarchy
891        let statuses = root.statuses();
892        assert_eq!(
893            statuses,
894            vec![
895                "root".to_string(),
896                "child1".to_string(),
897                "grandchild1".to_string()
898            ]
899        );
900
901        // update grandchild status
902        grandchild1.update_status("updated grandchild".to_string());
903        let statuses = root.statuses();
904        assert_eq!(
905            statuses,
906            vec![
907                "root".to_string(),
908                "child1".to_string(),
909                "updated grandchild".to_string()
910            ]
911        );
912
913        // update child status
914        child1.update_status("updated child1".to_string());
915        let statuses = root.statuses();
916        assert_eq!(
917            statuses,
918            vec![
919                "root".to_string(),
920                "updated child1".to_string(),
921                "updated grandchild".to_string()
922            ]
923        );
924
925        // update root status
926        root.update_status("updated root".to_string());
927        let statuses = root.statuses();
928        assert_eq!(
929            statuses,
930            vec![
931                "updated root".to_string(),
932                "updated child1".to_string(),
933                "updated grandchild".to_string()
934            ]
935        );
936    }
937
938    #[tokio::test]
939    async fn test_status_propagation_with_multiple_children() {
940        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
941        let child1 = root.child(0.5, "child1".to_string());
942        let child2 = root.child(0.5, "child2".to_string());
943
944        let grandchild1_1 = child1.child(0.7, "grandchild1_1".to_string());
945        let grandchild1_2 = child1.child(0.3, "grandchild1_2".to_string());
946        let grandchild2_1 = child2.child(1.0, "grandchild2_1".to_string());
947
948        // initial status hierarchy should show active path
949        let statuses = root.statuses();
950        assert_eq!(
951            statuses,
952            vec![
953                "root".to_string(),
954                "child1".to_string(),
955                "grandchild1_1".to_string()
956            ]
957        );
958
959        // update status of inactive grandchild
960        grandchild1_2.update_status("updated grandchild1_2".to_string());
961        let statuses = root.statuses();
962        assert_eq!(
963            statuses,
964            vec![
965                "root".to_string(),
966                "child1".to_string(),
967                "grandchild1_1".to_string()
968            ]
969        );
970
971        // update status of active grandchild
972        grandchild1_1.update_status("updated grandchild1_1".to_string());
973        let statuses = root.statuses();
974        assert_eq!(
975            statuses,
976            vec![
977                "root".to_string(),
978                "child1".to_string(),
979                "updated grandchild1_1".to_string()
980            ]
981        );
982
983        // update status of other branch's grandchild
984        grandchild2_1.update_status("updated grandchild2_1".to_string());
985        let statuses = root.statuses();
986        assert_eq!(
987            statuses,
988            vec![
989                "root".to_string(),
990                "child1".to_string(),
991                "updated grandchild1_1".to_string()
992            ]
993        );
994
995        // update status of other branch's child
996        child2.update_status("updated child2".to_string());
997        let statuses = root.statuses();
998        assert_eq!(
999            statuses,
1000            vec![
1001                "root".to_string(),
1002                "child1".to_string(),
1003                "updated grandchild1_1".to_string()
1004            ]
1005        );
1006    }
1007
1008    #[tokio::test]
1009    async fn test_status_propagation_with_completion() {
1010        let root: ProgressToken<String> = ProgressToken::new("root".to_string());
1011        let child1 = root.child(0.6, "child1".to_string());
1012        let child2 = root.child(0.4, "child2".to_string());
1013        let grandchild1 = child1.child(1.0, "grandchild1".to_string());
1014
1015        // initial status hierarchy
1016        let statuses = root.statuses();
1017        assert_eq!(
1018            statuses,
1019            vec![
1020                "root".to_string(),
1021                "child1".to_string(),
1022                "grandchild1".to_string()
1023            ]
1024        );
1025
1026        // update grandchild status and complete it
1027        grandchild1.update_status("completed grandchild".to_string());
1028        grandchild1.complete();
1029        let statuses = root.statuses();
1030        assert_eq!(statuses, vec!["root".to_string(), "child1".to_string()]);
1031
1032        // update child status and complete it
1033        child1.update_status("completed child1".to_string());
1034        child1.complete();
1035        let statuses = root.statuses();
1036        assert_eq!(statuses, vec!["root".to_string(), "child2".to_string()]);
1037
1038        // update remaining child status
1039        child2.update_status("updated child2".to_string());
1040        let statuses = root.statuses();
1041        assert_eq!(
1042            statuses,
1043            vec!["root".to_string(), "updated child2".to_string()]
1044        );
1045    }
1046}