Skip to main content

pipedream_rs/
subscription.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3use tokio::sync::mpsc;
4
5use crate::envelope::Envelope;
6use crate::error::RelayError;
7use crate::tracker::CompletionTracker;
8
9/// A typed subscription to a Relay.
10///
11/// Receives messages of type `T` from the stream.
12/// Type filtering happens locally - wrong types are skipped, not errored.
13///
14/// ## Observable Delivery
15///
16/// Subscriptions receive messages via try_send with buffering (default: 65536).
17/// If a subscriber's buffer fills, messages are dropped and observable via `Dropped` events.
18/// No backpressure - senders never block waiting for slow consumers.
19///
20/// ## Completion Semantics
21///
22/// There are two kinds of subscriptions:
23/// - **Tracked** (used by sink/tap): Participates in completion tracking.
24///   Signals completion for wrong-type messages, fails tracker on Drop.
25/// - **Untracked** (from `subscribe()`): Does NOT participate in tracking.
26///   Still receives all messages, but sender doesn't wait for completion.
27///
28/// Tracked subscriptions are created internally by sink/tap handlers.
29/// Raw subscriptions from `subscribe()` are untracked.
30pub struct Subscription<T> {
31    rx: mpsc::Receiver<Envelope>,
32    /// Tracker for the message currently being processed
33    current_tracker: Option<Arc<CompletionTracker>>,
34    /// Message ID of the message currently being processed
35    current_msg_id: Option<u64>,
36    /// Whether this subscription participates in completion tracking.
37    /// Set to true for sink/tap handlers, false for raw subscribe().
38    tracked: bool,
39    _marker: PhantomData<T>,
40}
41
42impl<T: 'static + Send + Sync> Subscription<T> {
43    /// Create an untracked subscription (for raw subscribe()).
44    /// Does NOT participate in completion tracking.
45    pub(crate) fn new(rx: mpsc::Receiver<Envelope>) -> Self {
46        Self {
47            rx,
48            current_tracker: None,
49            current_msg_id: None,
50            tracked: false,
51            _marker: PhantomData,
52        }
53    }
54
55    /// Create a tracked subscription (for sink/tap handlers).
56    /// Participates in completion tracking - signals completion for
57    /// wrong-type messages and fails tracker on Drop.
58    pub(crate) fn new_tracked(rx: mpsc::Receiver<Envelope>) -> Self {
59        Self {
60            rx,
61            current_tracker: None,
62            current_msg_id: None,
63            tracked: true,
64            _marker: PhantomData,
65        }
66    }
67
68    /// Get the completion tracker for the current message.
69    ///
70    /// Returns `None` if no message is being processed or if the message
71    /// was sent without completion tracking (fire-and-forget).
72    ///
73    /// Handlers should call `tracker.complete_one()` on success or
74    /// `tracker.fail(error)` on failure.
75    pub fn current_tracker(&self) -> Option<Arc<CompletionTracker>> {
76        self.current_tracker.clone()
77    }
78
79    /// Get the message ID of the current message being processed.
80    ///
81    /// Returns `None` if no message is being processed.
82    pub fn current_msg_id(&self) -> Option<u64> {
83        self.current_msg_id
84    }
85
86    /// Clear the current tracker without signaling completion.
87    ///
88    /// Called by handlers after they've explicitly completed the tracker.
89    /// This prevents double-completion and ensures Drop doesn't fail
90    /// an already-completed message.
91    pub fn clear_tracker(&mut self) {
92        self.current_tracker = None;
93        self.current_msg_id = None;
94    }
95
96    /// Receive the next message of type T.
97    ///
98    /// Returns `None` when the stream is closed.
99    /// Messages of other types are skipped automatically.
100    ///
101    /// **Important**: Handlers must call `complete_one()` or `fail()` on the
102    /// tracker before calling `recv()` again. This method does NOT auto-complete
103    /// for messages of the matching type.
104    pub async fn recv(&mut self) -> Option<Arc<T>> {
105        // Clear previous message state (handler should have already completed it)
106        self.current_tracker = None;
107        self.current_msg_id = None;
108
109        loop {
110            match self.rx.recv().await {
111                Some(env) => {
112                    if let Some(value) = env.downcast::<T>() {
113                        // Store tracker for this message
114                        self.current_tracker = env.tracker();
115                        self.current_msg_id = Some(env.msg_id());
116                        return Some(value);
117                    }
118                    // Wrong type - this handler can't process it.
119                    // Only tracked subscriptions (sink/tap) signal completion.
120                    if self.tracked {
121                        if let Some(tracker) = env.tracker() {
122                            tracker.complete_one();
123                        }
124                    }
125                }
126                None => return None, // Channel closed
127            }
128        }
129    }
130
131    /// Try to receive a message without waiting.
132    ///
133    /// Returns `None` if no message is available or the stream is closed.
134    ///
135    /// **Important**: Handlers must call `complete_one()` or `fail()` on the
136    /// tracker before calling `try_recv()` again. This method does NOT auto-complete
137    /// for messages of the matching type.
138    pub fn try_recv(&mut self) -> Option<Arc<T>> {
139        // Clear previous message state (handler should have already completed it)
140        self.current_tracker = None;
141        self.current_msg_id = None;
142
143        loop {
144            match self.rx.try_recv() {
145                Ok(env) => {
146                    if let Some(value) = env.downcast::<T>() {
147                        // Store tracker for this message
148                        self.current_tracker = env.tracker();
149                        self.current_msg_id = Some(env.msg_id());
150                        return Some(value);
151                    }
152                    // Wrong type - only tracked subscriptions signal completion
153                    if self.tracked {
154                        if let Some(tracker) = env.tracker() {
155                            tracker.complete_one();
156                        }
157                    }
158                }
159                Err(_) => return None,
160            }
161        }
162    }
163}
164
165impl<T> Drop for Subscription<T> {
166    fn drop(&mut self) {
167        // Safety net: only for tracked subscriptions.
168        // If dropped while processing and handler didn't complete, fail tracker.
169        // Untracked subscriptions don't participate in completion tracking.
170        if self.tracked {
171            if let Some(tracker) = self.current_tracker.take() {
172                let error = RelayError::new(
173                    self.current_msg_id.unwrap_or(0),
174                    std::io::Error::new(
175                        std::io::ErrorKind::Interrupted,
176                        "subscription dropped while processing message",
177                    ),
178                    "subscription",
179                );
180                tracker.fail(error);
181            }
182        }
183    }
184}
185
186impl<T> std::fmt::Debug for Subscription<T> {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        f.debug_struct("Subscription")
189            .field("type", &std::any::type_name::<T>())
190            .field("tracked", &self.tracked)
191            .field("has_current_tracker", &self.current_tracker.is_some())
192            .finish()
193    }
194}