pipedream_rs/tracker.rs
1use parking_lot::Mutex;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use tokio::sync::Notify;
8
9use crate::error::RelayError;
10
11/// Tracks completion of a single message across all handlers.
12///
13/// ## Semantics
14///
15/// The `handler_count` at send time determines `expected`. This is "best effort" -
16/// handlers registered concurrently with a send may or may not be counted.
17/// Document explicitly: send awaits approximately the current handler count;
18/// handlers added concurrently may or may not be accounted for.
19pub struct CompletionTracker {
20 expected: AtomicUsize,
21 completed: AtomicUsize,
22 error: Mutex<Option<RelayError>>,
23 notify: Arc<Notify>,
24}
25
26impl CompletionTracker {
27 /// Create a new tracker expecting `expected` handlers to complete.
28 ///
29 /// Note: `expected` is a snapshot at send time. Handlers registered
30 /// concurrently may not be included in this count.
31 pub fn new(expected: usize) -> Self {
32 Self {
33 expected: AtomicUsize::new(expected),
34 completed: AtomicUsize::new(0),
35 error: Mutex::new(None),
36 notify: Arc::new(Notify::new()),
37 }
38 }
39
40 /// Get the expected count.
41 pub fn expected(&self) -> usize {
42 self.expected.load(Ordering::SeqCst)
43 }
44
45 /// Get the completed count.
46 pub fn completed(&self) -> usize {
47 self.completed.load(Ordering::SeqCst)
48 }
49
50 /// Check if all handlers have completed.
51 pub fn is_complete(&self) -> bool {
52 self.completed() >= self.expected()
53 }
54
55 /// Signal successful completion of one handler.
56 ///
57 /// Calling this more than once per handler is a logic error but is
58 /// handled gracefully (extra completions are ignored after the
59 /// expected count is reached).
60 pub fn complete_one(&self) {
61 let completed = self.completed.fetch_add(1, Ordering::SeqCst) + 1;
62 let expected = self.expected.load(Ordering::SeqCst);
63
64 // Debug assertion to catch double-completion bugs during development
65 debug_assert!(
66 completed <= expected + 1,
67 "complete_one called {} times but only {} expected (possible double-completion bug)",
68 completed,
69 expected
70 );
71
72 if completed >= expected {
73 self.notify.notify_waiters();
74 }
75 }
76
77 /// Signal failure with an error.
78 pub fn fail(&self, error: RelayError) {
79 {
80 let mut err = self.error.lock();
81 if err.is_none() {
82 *err = Some(error);
83 }
84 }
85 self.complete_one();
86 }
87
88 /// Take the error if one occurred.
89 pub fn take_error(&self) -> Option<RelayError> {
90 self.error.lock().take()
91 }
92
93 /// Returns a future that completes when all handlers have finished.
94 ///
95 /// This is the proper way to await completion - it's a pollable future
96 /// that doesn't spawn tasks from within poll().
97 pub fn wait(&self) -> CompletionFuture<'_> {
98 CompletionFuture { tracker: self }
99 }
100
101 /// Async method to wait for completion.
102 ///
103 /// Alternative to `wait()` for use in async contexts where you want to
104 /// await directly.
105 pub async fn wait_async(&self) {
106 while !self.is_complete() {
107 self.notify.notified().await;
108 }
109 }
110}
111
112impl CompletionTracker {
113 /// Returns an owned future that completes when all handlers have finished.
114 ///
115 /// This version takes `Arc<Self>` and returns a `'static` future that can
116 /// be stored in a state machine.
117 pub async fn wait_owned(self: Arc<Self>) {
118 while !self.is_complete() {
119 self.notify.notified().await;
120 }
121 }
122}
123
124/// Future returned by `CompletionTracker::wait()`.
125///
126/// Polls completion state and awaits notification without spawning tasks.
127pub struct CompletionFuture<'a> {
128 tracker: &'a CompletionTracker,
129}
130
131impl<'a> Future for CompletionFuture<'a> {
132 type Output = ();
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 // Check if already complete
136 if self.tracker.is_complete() {
137 return Poll::Ready(());
138 }
139
140 // Register waker using enable() pattern - this is the proper way
141 // to integrate with Notify without storing futures
142 let notified = self.tracker.notify.notified();
143 // Pin it on the stack
144 futures::pin_mut!(notified);
145
146 // Enable the notified future to receive wake notifications
147 // This registers the current task's waker with the Notify
148 notified.as_mut().enable();
149
150 // Check again after registering (avoid race condition)
151 if self.tracker.is_complete() {
152 return Poll::Ready(());
153 }
154
155 // Now poll the notified future to properly register the waker
156 match notified.as_mut().poll(cx) {
157 Poll::Ready(()) => {
158 // Notification received, check completion again
159 if self.tracker.is_complete() {
160 Poll::Ready(())
161 } else {
162 // Not yet complete, wake self to retry
163 cx.waker().wake_by_ref();
164 Poll::Pending
165 }
166 }
167 Poll::Pending => Poll::Pending,
168 }
169 }
170}
171
172impl std::fmt::Debug for CompletionTracker {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 f.debug_struct("CompletionTracker")
175 .field("expected", &self.expected())
176 .field("completed", &self.completed())
177 .field("has_error", &self.error.lock().is_some())
178 .finish()
179 }
180}