1#![doc = include_str!("../README.md")]
2
3use futures::{Stream, ready};
4use pin_project_lite::pin_project;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use thiserror::Error;
9use tokio::sync::broadcast;
10use tokio_stream::wrappers::BroadcastStream;
11use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
12
13#[must_use = "if unused, the progress token will be completed immediately"]
15pub struct CompleteGuard<'a, S: Clone + Send + 'static> {
16 token: &'a ProgressToken<S>,
17}
18
19impl<'a, S: Clone + Send + 'static> CompleteGuard<'a, S> {
20 pub fn forget(self) {
22 std::mem::forget(self);
23 }
24}
25
26impl<'a, S: Clone + Send + 'static> Drop for CompleteGuard<'a, S> {
27 fn drop(&mut self) {
28 self.token.complete();
29 }
30}
31
/// Progress of a tracked operation.
///
/// `PartialEq` is derived so callers can compare values directly instead of
/// pattern-matching (note that `Determinate(f64::NAN)` is not equal to itself).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Progress {
    /// A known completion fraction, normally within `0.0..=1.0`.
    Determinate(f64),
    /// The amount of remaining work is unknown.
    Indeterminate,
}

impl Progress {
    /// Returns the completion fraction, or `None` for [`Progress::Indeterminate`].
    pub fn as_f64(&self) -> Option<f64> {
        match *self {
            Progress::Determinate(value) => Some(value),
            Progress::Indeterminate => None,
        }
    }
}
48
49#[derive(Debug, Clone, Copy, Error)]
50pub enum ProgressError {
51 #[error("progress updates lagged")]
54 Lagged,
55 #[error("the operation has been cancelled")]
57 Cancelled,
58}
59
60#[derive(Debug, Clone)]
62#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
63pub struct ProgressUpdate<S> {
64 pub progress: Progress,
65 pub statuses: Vec<S>,
66 pub is_cancelled: bool,
67}
68
69impl<S> ProgressUpdate<S> {
70 pub fn status(&self) -> &S {
71 self.statuses.last().unwrap()
72 }
73}
74
75struct ProgressNodeInner<S> {
77 parent: Option<Arc<ProgressNode<S>>>,
79 children: Vec<(Arc<ProgressNode<S>>, f64)>, progress: Progress,
83 status: S,
84 is_completed: bool,
85
86 update_sender: broadcast::Sender<ProgressUpdate<S>>,
88}
89
90struct ProgressNode<S> {
92 inner: Mutex<ProgressNodeInner<S>>,
93 }
95
96impl<S: Clone + Send> ProgressNode<S> {
97 fn new(status: S) -> Self {
98 let (tx, _) = broadcast::channel(16);
100
101 Self {
102 inner: Mutex::new(ProgressNodeInner {
103 parent: None,
104 children: Vec::new(),
105 progress: Progress::Determinate(0.0),
106 status,
107 is_completed: false,
108 update_sender: tx,
109 }),
110 }
112 }
113
114 fn child(parent: &Arc<Self>, weight: f64, status: S) -> Arc<Self> {
115 let mut parent_inner = parent.inner.lock().unwrap();
116
117 let (tx, _) = broadcast::channel(16);
119
120 let child = Self {
121 inner: Mutex::new(ProgressNodeInner {
122 parent: Some(parent.clone()),
123 children: Vec::new(),
124 progress: Progress::Determinate(0.0),
125 status,
126 is_completed: false,
127 update_sender: tx,
128 }),
129 };
131
132 let child = Arc::new(child);
133
134 parent_inner.children.push((child.clone(), weight));
135
136 child
137 }
138
139 fn calculate_progress(node: &Arc<Self>) -> Progress {
140 let inner = node.inner.lock().unwrap();
141
142 if matches!(inner.progress, Progress::Indeterminate) {
144 return Progress::Indeterminate;
145 }
146
147 if inner.children.is_empty() {
148 return inner.progress;
149 }
150
151 let has_indeterminate = inner
153 .children
154 .iter()
155 .filter(|(child, _)| {
156 let child_inner = child.inner.lock().unwrap();
157 !child_inner.is_completed
158 })
159 .any(|(child, _)| matches!(Self::calculate_progress(child), Progress::Indeterminate));
160
161 if has_indeterminate {
162 return Progress::Indeterminate;
163 }
164
165 let total: f64 = inner
167 .children
168 .iter()
169 .map(|(child, weight)| {
170 match Self::calculate_progress(child) {
171 Progress::Determinate(p) => p * weight,
172 Progress::Indeterminate => 0.0, }
174 })
175 .sum();
176
177 Progress::Determinate(total)
178 }
179
180 fn get_status_hierarchy(node: &Arc<Self>) -> Vec<S> {
181 let inner = node.inner.lock().unwrap();
182 let mut result = vec![inner.status.clone()];
183
184 if !inner.children.is_empty() {
186 let active_child = inner
187 .children
188 .iter()
189 .filter(|(child, _)| {
190 let child_inner = child.inner.lock().unwrap();
191 !child_inner.is_completed
192 })
193 .next();
194
195 if let Some((child, _)) = active_child {
196 let child_statuses = Self::get_status_hierarchy(child);
197 result.extend(child_statuses);
198 }
199 }
200
201 result
202 }
203
204 fn notify_subscribers(node: &Arc<Self>, is_cancelled: bool) {
205 let update = ProgressUpdate {
207 progress: Self::calculate_progress(node),
208 statuses: Self::get_status_hierarchy(node),
209 is_cancelled,
210 };
211
212 {
214 let inner = node.inner.lock().unwrap();
215 let _ = inner.update_sender.send(update);
217 };
218
219 let parent = {
224 let inner = node.inner.lock().unwrap();
225 inner.parent.clone()
226 };
227
228 if let Some(parent) = parent {
229 Self::notify_subscribers(&parent, false);
230 }
231 }
232}
233
234#[derive(Clone)]
236pub struct ProgressToken<S> {
237 node: Arc<ProgressNode<S>>,
238 cancel_token: CancellationToken,
239}
240
241impl<S: std::fmt::Debug> std::fmt::Debug for ProgressToken<S> {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 f.debug_struct("ProgressToken")
244 .field("is_cancelled", &self.cancel_token.is_cancelled())
245 .finish()
246 }
247}
248
249impl<S: Clone + Send + 'static> ProgressToken<S> {
250 pub fn new(status: impl Into<S>) -> Self {
252 let node = Arc::new(ProgressNode::new(status.into()));
253
254 Self {
255 node,
256 cancel_token: CancellationToken::new(),
257 }
258 }
259
260 pub fn child(&self, weight: f64, status: impl Into<S>) -> Self {
262 let node = ProgressNode::child(&self.node, weight, status.into());
263
264 Self {
265 node,
266 cancel_token: self.cancel_token.child_token(),
267 }
268 }
269
270 pub fn update_progress(&self, progress: f64) {
272 if self.is_cancelled() {
273 return;
274 }
275
276 let is_completed = {
277 let inner = self.node.inner.lock().unwrap();
278 inner.is_completed
279 };
280
281 if is_completed {
282 return;
283 }
284
285 let mut inner = self.node.inner.lock().unwrap();
286 inner.progress = Progress::Determinate(progress.max(0.0).min(1.0));
287 drop(inner);
288
289 ProgressNode::notify_subscribers(&self.node, false);
290 }
291
292 pub fn update_indeterminate(&self) {
294 if self.is_cancelled() {
295 return;
296 }
297
298 let mut inner = self.node.inner.lock().unwrap();
299 if inner.is_completed {
300 return;
301 }
302
303 inner.progress = Progress::Indeterminate;
304 drop(inner);
305
306 ProgressNode::notify_subscribers(&self.node, false);
307 }
308
309 pub fn update_status(&self, status: impl Into<S>) {
311 if self.is_cancelled() {
312 return;
313 }
314
315 let mut inner = self.node.inner.lock().unwrap();
316 if inner.is_completed {
317 return;
318 }
319
320 inner.status = status.into();
321 drop(inner);
322
323 ProgressNode::notify_subscribers(&self.node, false);
324 }
325
326 pub fn update(&self, progress: Progress, status: impl Into<S>) {
328 if self.is_cancelled() {
329 return;
330 }
331
332 let mut inner = self.node.inner.lock().unwrap();
333 if inner.is_completed {
334 return;
335 }
336
337 inner.status = status.into();
338 inner.progress = progress;
339 drop(inner);
340
341 ProgressNode::notify_subscribers(&self.node, false);
342 }
343
344 pub fn complete(&self) {
346 if self.is_cancelled() {
347 return;
348 }
349
350 let mut inner = self.node.inner.lock().unwrap();
351 if !inner.is_completed {
352 inner.is_completed = true;
353 inner.progress = Progress::Determinate(1.0);
354 drop(inner);
355
356 ProgressNode::notify_subscribers(&self.node, false);
357 }
358 }
359
360 pub fn check(&self) -> Result<(), ProgressError> {
362 if self.is_cancelled() {
363 Err(ProgressError::Cancelled)
364 } else {
365 Ok(())
366 }
367 }
368
369 pub fn is_cancelled(&self) -> bool {
370 self.cancel_token.is_cancelled()
371 }
372
373 pub fn cancel(&self) {
375 if !self.cancel_token.is_cancelled() {
376 self.cancel_token.cancel();
377 ProgressNode::notify_subscribers(&self.node, true);
378 }
379 }
380
381 pub fn state(&self) -> Progress {
383 ProgressNode::calculate_progress(&self.node)
384 }
385
386 pub fn statuses(&self) -> Vec<S> {
388 ProgressNode::get_status_hierarchy(&self.node)
389 }
390
391 pub fn cancelled(&self) -> WaitForCancellationFuture {
392 self.cancel_token.cancelled()
393 }
394
395 pub async fn updated(&self) -> Result<ProgressUpdate<S>, ProgressError> {
396 let mut rx = {
397 let inner = self.node.inner.lock().unwrap();
398 inner.update_sender.subscribe()
399 };
400
401 tokio::select! {
402 _ = self.cancel_token.cancelled() => {
403 Err(ProgressError::Cancelled)
404 }
405 result = rx.recv() => {
406 match result {
407 Ok(update) => Ok(update),
408 Err(broadcast::error::RecvError::Closed) => Err(ProgressError::Cancelled),
409 Err(broadcast::error::RecvError::Lagged(_)) => Err(ProgressError::Lagged),
410 }
411 }
412 }
413 }
414
415 pub fn subscribe(&self) -> ProgressStream<'_, S> {
417 let rx = {
418 let inner = self.node.inner.lock().unwrap();
419 inner.update_sender.subscribe()
420 };
421
422 ProgressStream {
423 token: self,
424 rx: BroadcastStream::new(rx),
425 }
426 }
427
428 pub fn complete_guard(&self) -> CompleteGuard<'_, S> {
430 CompleteGuard { token: self }
431 }
432}
433
434pin_project! {
435 #[must_use = "futures do nothing unless polled"]
438 pub struct WaitForUpdateFuture<'a, S> {
439 token: &'a ProgressToken<S>,
440 #[pin]
441 future: tokio::sync::futures::Notified<'a>,
442 }
443}
444
445impl<'a, S: Clone + Send + 'static> Future for WaitForUpdateFuture<'a, S> {
446 type Output = Option<ProgressUpdate<S>>;
447
448 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
449 let mut this = self.project();
450 if this.token.cancel_token.is_cancelled() {
451 return Poll::Ready(None);
452 }
453
454 ready!(this.future.as_mut().poll(cx));
455
456 Poll::Ready(Some(ProgressUpdate {
457 progress: this.token.state(),
458 statuses: this.token.statuses(),
459 is_cancelled: false,
460 }))
461 }
462}
463
464pin_project! {
465 #[must_use = "streams do nothing unless polled"]
467 pub struct ProgressStream<'a, S> {
468 token: &'a ProgressToken<S>,
469 #[pin]
470 rx: BroadcastStream<ProgressUpdate<S>>,
471 }
472}
473
474impl<'a, S: Clone + Send + 'static> Stream for ProgressStream<'a, S> {
475 type Item = ProgressUpdate<S>;
476
477 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
478 self.project()
479 .rx
480 .poll_next(cx)
481 .map(|opt| opt.map(|res| res.unwrap()))
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;
    use tokio::time::sleep;

    /// True when `progress` is determinate and within `f64::EPSILON` of `expected`.
    fn is_near(progress: Progress, expected: f64) -> bool {
        matches!(progress, Progress::Determinate(p) if (p - expected).abs() < f64::EPSILON)
    }

    /// Builds a root token with two children weighted 0.6 and 0.4.
    async fn create_test_hierarchy() -> (
        ProgressToken<String>,
        ProgressToken<String>,
        ProgressToken<String>,
    ) {
        let root = ProgressToken::new("root".to_string());
        let first = root.child(0.6, "child1".to_string());
        let second = root.child(0.4, "child2".to_string());
        (root, first, second)
    }

    #[tokio::test]
    async fn test_basic_progress_updates() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        // Values outside 0..=1 are clamped.
        for (input, expected) in [(0.5, 0.5), (1.0, 1.0), (1.5, 1.0), (-0.5, 0.0)] {
            token.update_progress(input);
            assert!(is_near(token.state(), expected));
        }
    }

    #[tokio::test]
    async fn test_hierarchical_progress() {
        let (root, first, second) = create_test_hierarchy().await;

        first.update_progress(0.5);
        second.update_progress(0.5);
        assert!(is_near(root.state(), 0.5));

        first.update_progress(1.0);
        assert!(is_near(root.state(), 0.8));
    }

    #[tokio::test]
    async fn test_indeterminate_state() {
        let (root, first, second) = create_test_hierarchy().await;

        first.update_indeterminate();
        second.update_progress(0.5);
        assert!(matches!(root.state(), Progress::Indeterminate));

        first.update_progress(0.5);
        assert!(matches!(root.state(), Progress::Determinate(_)));
    }

    #[tokio::test]
    async fn test_status_updates() {
        let token: ProgressToken<String> = ProgressToken::new("initial status".to_string());
        assert_eq!(token.statuses(), vec!["initial status".to_string()]);

        token.update_status("updated status".to_string());
        assert_eq!(token.statuses(), vec!["updated status".to_string()]);
    }

    #[tokio::test]
    async fn test_status_hierarchy() {
        let (root, first, _) = create_test_hierarchy().await;
        assert_eq!(
            root.statuses(),
            vec!["root".to_string(), "child1".to_string()]
        );

        first.update_status("updated child1".to_string());
        assert_eq!(
            root.statuses(),
            vec!["root".to_string(), "updated child1".to_string()]
        );
    }

    #[tokio::test]
    async fn test_cancellation() {
        let (root, first, second) = create_test_hierarchy().await;

        root.cancel();
        for token in [&root, &first, &second] {
            assert!(token.cancel_token.is_cancelled());
        }

        // Updates on a cancelled token are ignored.
        first.update_progress(0.5);
        assert!(is_near(first.state(), 0.0));
    }

    #[tokio::test]
    async fn test_complete_guard() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        {
            let _guard = token.complete_guard();
            token.update_progress(0.5);
            assert!(is_near(token.state(), 0.5));
        }
        // Dropping the guard completed the token...
        assert!(is_near(token.state(), 1.0));

        // ...and a completed token ignores later updates.
        token.update_progress(0.5);
        assert!(is_near(token.state(), 1.0));

        let token: ProgressToken<String> = ProgressToken::new("test2".to_string());
        {
            let guard = token.complete_guard();
            token.update_progress(0.5);
            guard.forget();
        }
        assert!(is_near(token.state(), 0.5));
    }

    #[tokio::test]
    async fn test_subscription() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        let mut updates = token.subscribe();

        let first = updates.next().await.unwrap();
        assert_eq!(first.status(), &"test".to_string());
        assert!(is_near(first.progress, 0.0));

        token.update_progress(0.5);
        let second = updates.next().await.unwrap();
        assert!(is_near(second.progress, 0.5));
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());
        let mut sub1 = token.subscribe();
        let mut sub2 = token.subscribe();

        sub1.next().await.unwrap();
        sub2.next().await.unwrap();

        token.update_progress(0.5);

        let update1 = sub1.next().await.unwrap();
        let update2 = sub2.next().await.unwrap();

        assert!(is_near(update1.progress, 0.5), "{update1:?}");
        assert!(is_near(update2.progress, 0.5), "{update2:?}");
    }

    #[tokio::test]
    async fn test_concurrent_updates() {
        let token: ProgressToken<String> = ProgressToken::new("test".to_string());

        let handles: Vec<_> = (0..10u64)
            .map(|step| {
                let token = token.clone();
                tokio::spawn(async move {
                    sleep(Duration::from_millis(step * 10)).await;
                    token.update_progress(step as f64 / 10.0);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert!(is_near(token.state(), 0.9));
    }

    #[tokio::test]
    async fn test_edge_cases() {
        let single: ProgressToken<String> = ProgressToken::new("single".to_string());
        single.update_progress(0.5);
        assert!(is_near(single.state(), 0.5));

        // A chain ten levels deep, each child carrying the full weight.
        let deepest = (0..10).fold(
            ProgressToken::<String>::new("root".to_string()),
            |parent, i| parent.child(1.0, format!("child{}", i)),
        );
        deepest.update_progress(1.0);
        assert!(is_near(deepest.state(), 1.0));
    }
}