Skip to main content

pipedream_rs/
tracker.rs

1use parking_lot::Mutex;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use tokio::sync::Notify;
8
9use crate::error::RelayError;
10
11/// Tracks completion of a single message across all handlers.
12///
13/// ## Semantics
14///
15/// The `handler_count` at send time determines `expected`. This is "best effort" -
16/// handlers registered concurrently with a send may or may not be counted.
17/// Document explicitly: send awaits approximately the current handler count;
18/// handlers added concurrently may or may not be accounted for.
19pub struct CompletionTracker {
20    expected: AtomicUsize,
21    completed: AtomicUsize,
22    error: Mutex<Option<RelayError>>,
23    notify: Arc<Notify>,
24}
25
26impl CompletionTracker {
27    /// Create a new tracker expecting `expected` handlers to complete.
28    ///
29    /// Note: `expected` is a snapshot at send time. Handlers registered
30    /// concurrently may not be included in this count.
31    pub fn new(expected: usize) -> Self {
32        Self {
33            expected: AtomicUsize::new(expected),
34            completed: AtomicUsize::new(0),
35            error: Mutex::new(None),
36            notify: Arc::new(Notify::new()),
37        }
38    }
39
40    /// Get the expected count.
41    pub fn expected(&self) -> usize {
42        self.expected.load(Ordering::SeqCst)
43    }
44
45    /// Get the completed count.
46    pub fn completed(&self) -> usize {
47        self.completed.load(Ordering::SeqCst)
48    }
49
50    /// Check if all handlers have completed.
51    pub fn is_complete(&self) -> bool {
52        self.completed() >= self.expected()
53    }
54
55    /// Signal successful completion of one handler.
56    ///
57    /// Calling this more than once per handler is a logic error but is
58    /// handled gracefully (extra completions are ignored after the
59    /// expected count is reached).
60    pub fn complete_one(&self) {
61        let completed = self.completed.fetch_add(1, Ordering::SeqCst) + 1;
62        let expected = self.expected.load(Ordering::SeqCst);
63
64        // Debug assertion to catch double-completion bugs during development
65        debug_assert!(
66            completed <= expected + 1,
67            "complete_one called {} times but only {} expected (possible double-completion bug)",
68            completed,
69            expected
70        );
71
72        if completed >= expected {
73            self.notify.notify_waiters();
74        }
75    }
76
77    /// Signal failure with an error.
78    pub fn fail(&self, error: RelayError) {
79        {
80            let mut err = self.error.lock();
81            if err.is_none() {
82                *err = Some(error);
83            }
84        }
85        self.complete_one();
86    }
87
88    /// Take the error if one occurred.
89    pub fn take_error(&self) -> Option<RelayError> {
90        self.error.lock().take()
91    }
92
93    /// Returns a future that completes when all handlers have finished.
94    ///
95    /// This is the proper way to await completion - it's a pollable future
96    /// that doesn't spawn tasks from within poll().
97    pub fn wait(&self) -> CompletionFuture<'_> {
98        CompletionFuture { tracker: self }
99    }
100
101    /// Async method to wait for completion.
102    ///
103    /// Alternative to `wait()` for use in async contexts where you want to
104    /// await directly.
105    pub async fn wait_async(&self) {
106        while !self.is_complete() {
107            self.notify.notified().await;
108        }
109    }
110}
111
112impl CompletionTracker {
113    /// Returns an owned future that completes when all handlers have finished.
114    ///
115    /// This version takes `Arc<Self>` and returns a `'static` future that can
116    /// be stored in a state machine.
117    pub async fn wait_owned(self: Arc<Self>) {
118        while !self.is_complete() {
119            self.notify.notified().await;
120        }
121    }
122}
123
124/// Future returned by `CompletionTracker::wait()`.
125///
126/// Polls completion state and awaits notification without spawning tasks.
127pub struct CompletionFuture<'a> {
128    tracker: &'a CompletionTracker,
129}
130
131impl<'a> Future for CompletionFuture<'a> {
132    type Output = ();
133
134    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135        // Check if already complete
136        if self.tracker.is_complete() {
137            return Poll::Ready(());
138        }
139
140        // Register waker using enable() pattern - this is the proper way
141        // to integrate with Notify without storing futures
142        let notified = self.tracker.notify.notified();
143        // Pin it on the stack
144        futures::pin_mut!(notified);
145
146        // Enable the notified future to receive wake notifications
147        // This registers the current task's waker with the Notify
148        notified.as_mut().enable();
149
150        // Check again after registering (avoid race condition)
151        if self.tracker.is_complete() {
152            return Poll::Ready(());
153        }
154
155        // Now poll the notified future to properly register the waker
156        match notified.as_mut().poll(cx) {
157            Poll::Ready(()) => {
158                // Notification received, check completion again
159                if self.tracker.is_complete() {
160                    Poll::Ready(())
161                } else {
162                    // Not yet complete, wake self to retry
163                    cx.waker().wake_by_ref();
164                    Poll::Pending
165                }
166            }
167            Poll::Pending => Poll::Pending,
168        }
169    }
170}
171
172impl std::fmt::Debug for CompletionTracker {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        f.debug_struct("CompletionTracker")
175            .field("expected", &self.expected())
176            .field("completed", &self.completed())
177            .field("has_error", &self.error.lock().is_some())
178            .finish()
179    }
180}