1use futures::{Stream, ready};
112use pin_project_lite::pin_project;
113use std::pin::Pin;
114use std::sync::atomic::{AtomicBool, Ordering};
115use std::sync::{Arc, Mutex};
116use std::task::{Context, Poll};
117use tokio::sync::broadcast;
118use tokio_stream::wrappers::BroadcastStream;
119use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
120
/// How far along a unit of work is.
#[derive(Debug, Clone, Copy)]
pub enum Progress {
    /// Completion is known; the payload is the measured value
    /// (`ProgressToken::progress` stores values clamped to `0.0..=1.0`).
    Determinate(f64),
    /// Work is under way but no meaningful value can be reported.
    Indeterminate,
}

impl Progress {
    /// Returns the numeric value of a [`Progress::Determinate`] state, or
    /// `None` when the progress is [`Progress::Indeterminate`].
    pub fn as_f64(&self) -> Option<f64> {
        if let Progress::Determinate(value) = *self {
            Some(value)
        } else {
            None
        }
    }
}
136
/// Reasons why waiting for a progress update can fail.
///
/// Implements [`std::error::Error`], so it can be propagated with `?` into
/// `Box<dyn Error>` or any error type that wraps standard errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressError {
    /// The receiver fell behind the update channel and missed updates.
    Lagged,
    /// The token was cancelled, or its update channel was closed.
    Cancelled,
}

impl std::fmt::Display for ProgressError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ProgressError::Lagged => "progress subscriber lagged behind and missed updates",
            ProgressError::Cancelled => "progress operation was cancelled",
        })
    }
}

impl std::error::Error for ProgressError {}
145
146#[derive(Debug, Clone)]
148pub struct ProgressUpdate<S> {
149 pub progress: Progress,
150 pub statuses: Vec<S>,
151 pub is_cancelled: bool,
152}
153
154impl<S> ProgressUpdate<S> {
155 pub fn status(&self) -> &S {
156 self.statuses.last().unwrap()
157 }
158}
159
160struct ProgressNodeInner<S> {
162 parent: Option<Arc<ProgressNode<S>>>,
164 children: Vec<(Arc<ProgressNode<S>>, f64)>, progress: Progress,
168 status: S,
169 is_completed: bool,
170
171 update_sender: broadcast::Sender<ProgressUpdate<S>>,
173}
174
175struct ProgressNode<S> {
177 inner: Mutex<ProgressNodeInner<S>>,
178 }
180
181impl<S: Clone + Send> ProgressNode<S> {
182 fn new(status: S) -> Self {
183 let (tx, _) = broadcast::channel(16);
185
186 Self {
187 inner: Mutex::new(ProgressNodeInner {
188 parent: None,
189 children: Vec::new(),
190 progress: Progress::Determinate(0.0),
191 status,
192 is_completed: false,
193 update_sender: tx,
194 }),
195 }
197 }
198
199 fn child(parent: &Arc<Self>, weight: f64, status: S) -> Arc<Self> {
200 let mut parent_inner = parent.inner.lock().unwrap();
201
202 let (tx, _) = broadcast::channel(16);
204
205 let child = Self {
206 inner: Mutex::new(ProgressNodeInner {
207 parent: Some(parent.clone()),
208 children: Vec::new(),
209 progress: Progress::Determinate(0.0),
210 status,
211 is_completed: false,
212 update_sender: tx,
213 }),
214 };
216
217 let child = Arc::new(child);
218
219 parent_inner.children.push((child.clone(), weight));
220
221 child
222 }
223
224 fn calculate_progress(node: &Arc<Self>) -> Progress {
225 let inner = node.inner.lock().unwrap();
226
227 if matches!(inner.progress, Progress::Indeterminate) {
229 return Progress::Indeterminate;
230 }
231
232 if inner.children.is_empty() {
233 return inner.progress;
234 }
235
236 let has_indeterminate = inner
238 .children
239 .iter()
240 .filter(|(child, _)| {
241 let child_inner = child.inner.lock().unwrap();
242 !child_inner.is_completed
243 })
244 .any(|(child, _)| matches!(Self::calculate_progress(child), Progress::Indeterminate));
245
246 if has_indeterminate {
247 return Progress::Indeterminate;
248 }
249
250 let total: f64 = inner
252 .children
253 .iter()
254 .map(|(child, weight)| {
255 match Self::calculate_progress(child) {
256 Progress::Determinate(p) => p * weight,
257 Progress::Indeterminate => 0.0, }
259 })
260 .sum();
261
262 Progress::Determinate(total)
263 }
264
265 fn get_status_hierarchy(node: &Arc<Self>) -> Vec<S> {
266 let inner = node.inner.lock().unwrap();
267 let mut result = vec![inner.status.clone()];
268
269 if !inner.children.is_empty() {
271 let active_child = inner
272 .children
273 .iter()
274 .filter(|(child, _)| {
275 let child_inner = child.inner.lock().unwrap();
276 !child_inner.is_completed
277 })
278 .next();
279
280 if let Some((child, _)) = active_child {
281 let child_statuses = Self::get_status_hierarchy(child);
282 result.extend(child_statuses);
283 }
284 }
285
286 result
287 }
288
289 fn notify_subscribers(node: &Arc<Self>, is_cancelled: bool) {
290 let update = ProgressUpdate {
292 progress: Self::calculate_progress(node),
293 statuses: Self::get_status_hierarchy(node),
294 is_cancelled,
295 };
296
297 {
299 let inner = node.inner.lock().unwrap();
300 let _ = inner.update_sender.send(update);
302 };
303
304 let parent = {
309 let inner = node.inner.lock().unwrap();
310 inner.parent.clone()
311 };
312
313 if let Some(parent) = parent {
314 Self::notify_subscribers(&parent, false);
315 }
316 }
317}
318
319#[derive(Clone)]
321pub struct ProgressToken<S> {
322 node: Arc<ProgressNode<S>>,
323 is_active: Arc<AtomicBool>,
324 cancel_token: CancellationToken,
325}
326
327impl<S: Clone + Send + 'static> ProgressToken<S> {
328 pub fn new(status: impl Into<S>) -> Arc<Self> {
330 let node = Arc::new(ProgressNode::new(status.into()));
331
332 Arc::new(Self {
333 node,
334 is_active: Arc::new(AtomicBool::new(true)),
335 cancel_token: CancellationToken::new(),
336 })
337 }
338
339 pub fn child(parent: &Arc<Self>, weight: f64, status: impl Into<S>) -> Arc<Self> {
341 let node = ProgressNode::child(&parent.node, weight, status.into());
342
343 Arc::new(Self {
344 node,
345 is_active: Arc::new(AtomicBool::new(true)),
346 cancel_token: parent.cancel_token.child_token(),
347 })
348 }
349
350 pub fn progress(&self, progress: f64) {
352 if !self.is_active.load(Ordering::Relaxed) || self.cancel_token.is_cancelled() {
353 return;
354 }
355
356 let mut inner = self.node.inner.lock().unwrap();
362 inner.progress = Progress::Determinate(progress.max(0.0).min(1.0));
363 drop(inner);
364
365 ProgressNode::notify_subscribers(&self.node, false);
366 }
368
369 pub fn indeterminate(&self) {
371 if !self.is_active.load(Ordering::Relaxed) || self.cancel_token.is_cancelled() {
372 return;
373 }
374
375 let mut inner = self.node.inner.lock().unwrap();
376 inner.progress = Progress::Indeterminate;
377 drop(inner);
378
379 ProgressNode::notify_subscribers(&self.node, false);
380 }
381
382 pub fn status(&self, status: impl Into<S>) {
384 if !self.is_active.load(Ordering::Relaxed) || self.cancel_token.is_cancelled() {
385 return;
386 }
387
388 let mut inner = self.node.inner.lock().unwrap();
389 inner.status = status.into();
390 drop(inner);
391
392 ProgressNode::notify_subscribers(&self.node, false);
393 }
394
395 pub fn complete(&self) {
397 if self.is_active.swap(false, Ordering::Relaxed) {
398 let mut inner = self.node.inner.lock().unwrap();
399 inner.is_completed = true;
400 inner.progress = Progress::Determinate(1.0);
401 drop(inner);
402
403 ProgressNode::notify_subscribers(&self.node, false);
404 }
405 }
406
407 pub fn cancel(&self) {
409 if self.is_active.swap(false, Ordering::Relaxed) {
410 self.cancel_token.cancel();
411
412 ProgressNode::notify_subscribers(&self.node, true);
413 }
414 }
415
416 pub fn is_cancelled(&self) -> bool {
418 self.cancel_token.is_cancelled()
419 }
420
421 pub fn state(&self) -> Progress {
423 ProgressNode::calculate_progress(&self.node)
424 }
425
426 pub fn statuses(&self) -> Vec<S> {
428 ProgressNode::get_status_hierarchy(&self.node)
429 }
430
431 pub fn cancelled(&self) -> WaitForCancellationFuture {
432 self.cancel_token.cancelled()
433 }
434
435 pub async fn updated(&self) -> Result<ProgressUpdate<S>, ProgressError> {
436 let mut rx = {
437 let inner = self.node.inner.lock().unwrap();
438 inner.update_sender.subscribe()
439 };
440
441 tokio::select! {
442 _ = self.cancel_token.cancelled() => {
443 Err(ProgressError::Cancelled)
444 }
445 result = rx.recv() => {
446 match result {
447 Ok(update) => Ok(update),
448 Err(broadcast::error::RecvError::Closed) => Err(ProgressError::Cancelled),
449 Err(broadcast::error::RecvError::Lagged(_)) => Err(ProgressError::Lagged),
450 }
451 }
452 }
453 }
454
455 pub fn subscribe(&self) -> ProgressStream<'_, S> {
457 let rx = {
458 let inner = self.node.inner.lock().unwrap();
459 inner.update_sender.subscribe()
460 };
461
462 ProgressStream {
463 token: self,
464 rx: BroadcastStream::new(rx),
465 }
466 }
467}
468
469pin_project! {
470 #[must_use = "futures do nothing unless polled"]
473 pub struct WaitForUpdateFuture<'a, S> {
474 token: &'a ProgressToken<S>,
475 #[pin]
476 future: tokio::sync::futures::Notified<'a>,
477 }
478}
479
480impl<'a, S: Clone + Send + 'static> Future for WaitForUpdateFuture<'a, S> {
481 type Output = Option<ProgressUpdate<S>>;
482
483 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
484 let mut this = self.project();
485 if this.token.is_cancelled() {
486 return Poll::Ready(None);
487 }
488
489 ready!(this.future.as_mut().poll(cx));
490
491 Poll::Ready(Some(ProgressUpdate {
492 progress: this.token.state(),
493 statuses: this.token.statuses(),
494 is_cancelled: false,
495 }))
496 }
497}
498
499pin_project! {
500 #[must_use = "streams do nothing unless polled"]
502 pub struct ProgressStream<'a, S> {
503 token: &'a ProgressToken<S>,
504 #[pin]
505 rx: BroadcastStream<ProgressUpdate<S>>,
506 }
507}
508
509impl<'a, S: Clone + Send + 'static> Stream for ProgressStream<'a, S> {
510 type Item = ProgressUpdate<S>;
511
512 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
513 self.project()
514 .rx
515 .poll_next(cx)
516 .map(|opt| opt.map(|res| res.unwrap()))
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Token type used throughout the tests. Naming `String` explicitly is
    /// required: `impl Into<S>` alone does not let the compiler infer `S`.
    type Token = Arc<ProgressToken<String>>;

    /// Asserts that `actual` is determinate and within `f64::EPSILON` of
    /// `expected`.
    fn assert_determinate(actual: Progress, expected: f64) {
        assert!(
            matches!(actual, Progress::Determinate(p) if (p - expected).abs() < f64::EPSILON),
            "expected Determinate({expected}), got {actual:?}"
        );
    }

    /// Builds a root with two children weighted 0.6 and 0.4.
    fn create_test_hierarchy() -> (Token, Token, Token) {
        let root = ProgressToken::<String>::new("root");
        let child1 = ProgressToken::child(&root, 0.6, "child1");
        let child2 = ProgressToken::child(&root, 0.4, "child2");
        (root, child1, child2)
    }

    #[tokio::test]
    async fn test_basic_progress_updates() {
        let token = ProgressToken::<String>::new("test");

        token.progress(0.5);
        assert_determinate(token.state(), 0.5);

        token.progress(1.0);
        assert_determinate(token.state(), 1.0);

        // Values above the range are limited to 1.0.
        token.progress(1.5);
        assert_determinate(token.state(), 1.0);

        // Values below the range are limited to 0.0.
        token.progress(-0.5);
        assert_determinate(token.state(), 0.0);
    }

    #[tokio::test]
    async fn test_hierarchical_progress() {
        let (root, child1, child2) = create_test_hierarchy();

        child1.progress(0.5);
        child2.progress(0.5);

        // 0.5 * 0.6 + 0.5 * 0.4
        assert_determinate(root.state(), 0.5);

        child1.progress(1.0);
        // 1.0 * 0.6 + 0.5 * 0.4
        assert_determinate(root.state(), 0.8);
    }

    #[tokio::test]
    async fn test_indeterminate_state() {
        let (root, child1, child2) = create_test_hierarchy();

        child1.indeterminate();
        child2.progress(0.5);

        // One indeterminate child makes the parent indeterminate.
        assert!(matches!(root.state(), Progress::Indeterminate));

        child1.progress(0.5);
        assert!(matches!(root.state(), Progress::Determinate(_)));
    }

    #[tokio::test]
    async fn test_status_updates() {
        let token = ProgressToken::<String>::new("initial status");
        assert_eq!(token.statuses(), ["initial status"]);

        token.status("updated status");
        assert_eq!(token.statuses(), ["updated status"]);
    }

    #[tokio::test]
    async fn test_status_hierarchy() {
        let (root, child1, _child2) = create_test_hierarchy();

        assert_eq!(root.statuses(), ["root", "child1"]);

        child1.status("updated child1");
        assert_eq!(root.statuses(), ["root", "updated child1"]);
    }

    #[tokio::test]
    async fn test_cancellation() {
        let (root, child1, child2) = create_test_hierarchy();

        root.cancel();

        assert!(root.is_cancelled());
        assert!(child1.is_cancelled());
        assert!(child2.is_cancelled());

        // Updates on a cancelled token are ignored.
        child1.progress(0.5);
        assert_determinate(child1.state(), 0.0);
    }

    #[tokio::test]
    async fn test_completion() {
        let token = ProgressToken::<String>::new("test");
        token.complete();
        assert_determinate(token.state(), 1.0);

        // Updates after completion are ignored.
        token.progress(0.5);
        assert_determinate(token.state(), 1.0);
    }

    #[tokio::test]
    async fn test_subscription() {
        let token = ProgressToken::<String>::new("test");
        let mut subscription = token.subscribe();

        // The stream only carries updates published after `subscribe`, so an
        // update has to be triggered before awaiting the first item.
        token.progress(0.5);
        let update = subscription.next().await.unwrap();
        assert_eq!(update.status().as_str(), "test");
        assert_determinate(update.progress, 0.5);
        assert!(!update.is_cancelled);

        token.status("working");
        let update = subscription.next().await.unwrap();
        assert_eq!(update.status().as_str(), "working");
        assert_determinate(update.progress, 0.5);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let token = ProgressToken::<String>::new("test");
        let mut sub1 = token.subscribe();
        let mut sub2 = token.subscribe();

        token.progress(0.5);

        // Every subscriber receives its own copy of the update.
        let update1 = sub1.next().await.unwrap();
        let update2 = sub2.next().await.unwrap();

        assert_determinate(update1.progress, 0.5);
        assert_determinate(update2.progress, 0.5);
    }

    #[tokio::test]
    async fn test_concurrent_updates() {
        let token = ProgressToken::<String>::new("test");

        // Staggered sleeps make the task with the largest `i` report last.
        let handles: Vec<_> = (0..10u64)
            .map(|i| {
                let token = Arc::clone(&token);
                tokio::spawn(async move {
                    sleep(Duration::from_millis(i * 10)).await;
                    token.progress(i as f64 / 10.0);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_determinate(token.state(), 0.9);
    }

    #[tokio::test]
    async fn test_edge_cases() {
        // A single node without children reports its own progress.
        let token = ProgressToken::<String>::new("single");
        token.progress(0.5);
        assert_determinate(token.state(), 0.5);

        // A deep chain of single children.
        let mut current = ProgressToken::<String>::new("root");
        for i in 0..10 {
            current = ProgressToken::child(&current, 1.0, format!("child{i}"));
        }

        current.progress(1.0);
        assert_determinate(current.state(), 1.0);
    }
}