Skip to main content

atproto_tap/
stream.rs

1//! TAP event stream implementation.
2//!
3//! This module provides [`TapStream`], an async stream that yields TAP events
4//! with automatic connection management and reconnection handling.
5//!
6//! # Design
7//!
8//! The stream encapsulates all connection logic, allowing consumers to simply
9//! iterate over events using standard stream combinators or `tokio::select!`.
10//!
11//! Reconnection is handled automatically with exponential backoff. Parse errors
12//! are yielded as `Err` items but don't affect connection state - only connection
13//! errors trigger reconnection attempts.
14
15use crate::config::TapConfig;
16use crate::connection::TapConnection;
17use crate::errors::TapError;
18use crate::events::{TapEvent, extract_event_id};
19use futures::Stream;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::Duration;
24use tokio::sync::mpsc;
25
26/// An async stream of TAP events with automatic reconnection.
27///
28/// `TapStream` implements [`Stream`] and yields `Result<Arc<TapEvent>, TapError>`.
29/// Events are wrapped in `Arc` for efficient zero-cost sharing across consumers.
30///
31/// # Connection Management
32///
33/// The stream automatically:
34/// - Connects on first poll
35/// - Reconnects with exponential backoff on connection errors
36/// - Sends acknowledgments after parsing each message (if enabled)
37/// - Yields parse errors without affecting connection state
38///
39/// # Example
40///
41/// ```ignore
42/// use atproto_tap::{TapConfig, TapStream};
43/// use tokio_stream::StreamExt;
44///
45/// let config = TapConfig::builder()
46///     .hostname("localhost:2480")
47///     .build();
48///
49/// let mut stream = TapStream::new(config);
50///
51/// while let Some(result) = stream.next().await {
52///     match result {
53///         Ok(event) => println!("Event: {:?}", event),
54///         Err(e) => eprintln!("Error: {}", e),
55///     }
56/// }
57/// ```
58pub struct TapStream {
59    /// Receiver for events from the background task.
60    receiver: mpsc::Receiver<Result<Arc<TapEvent>, TapError>>,
61    /// Handle to request stream closure.
62    close_sender: Option<mpsc::Sender<()>>,
63    /// Whether the stream has been closed.
64    closed: bool,
65}
66
67impl TapStream {
68    /// Create a new TAP stream with the given configuration.
69    ///
70    /// The stream will start connecting immediately in a background task.
71    pub fn new(config: TapConfig) -> Self {
72        // Channel for events - buffer a few to handle bursts
73        let (event_tx, event_rx) = mpsc::channel(config.channel_buffer_size);
74        // Channel for close signal
75        let (close_tx, close_rx) = mpsc::channel(1);
76
77        // Spawn background task to manage connection
78        tokio::spawn(connection_task(config, event_tx, close_rx));
79
80        Self {
81            receiver: event_rx,
82            close_sender: Some(close_tx),
83            closed: false,
84        }
85    }
86
87    /// Close the stream and release resources.
88    ///
89    /// After calling this, the stream will yield `None` on the next poll.
90    pub async fn close(&mut self) {
91        if let Some(sender) = self.close_sender.take() {
92            // Signal the background task to close
93            let _ = sender.send(()).await;
94        }
95        self.closed = true;
96    }
97
98    /// Returns true if the stream is closed.
99    pub fn is_closed(&self) -> bool {
100        self.closed
101    }
102}
103
104impl Stream for TapStream {
105    type Item = Result<Arc<TapEvent>, TapError>;
106
107    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        if self.closed {
109            return Poll::Ready(None);
110        }
111
112        self.receiver.poll_recv(cx)
113    }
114}
115
116impl Drop for TapStream {
117    fn drop(&mut self) {
118        // Drop the close_sender to signal the background task
119        self.close_sender.take();
120        tracing::debug!("TapStream dropped");
121    }
122}
123
124/// Background task that manages the WebSocket connection.
125async fn connection_task(
126    config: TapConfig,
127    event_tx: mpsc::Sender<Result<Arc<TapEvent>, TapError>>,
128    mut close_rx: mpsc::Receiver<()>,
129) {
130    let mut current_reconnect_delay = config.initial_reconnect_delay;
131    let mut attempt: u32 = 0;
132
133    loop {
134        // Check for close signal
135        if close_rx.try_recv().is_ok() {
136            tracing::debug!("Connection task received close signal");
137            break;
138        }
139
140        // Try to connect
141        tracing::debug!(attempt, hostname = %config.hostname, "Connecting to TAP service");
142        let conn_result = TapConnection::connect(&config).await;
143
144        match conn_result {
145            Ok(mut conn) => {
146                tracing::info!(hostname = %config.hostname, "TAP stream connected");
147                // Reset reconnection state on successful connect
148                current_reconnect_delay = config.initial_reconnect_delay;
149                attempt = 0;
150
151                // Event loop for this connection
152                loop {
153                    tokio::select! {
154                        biased;
155
156                        _ = close_rx.recv() => {
157                            tracing::debug!("Connection task received close signal during receive");
158                            let _ = conn.close().await;
159                            return;
160                        }
161
162                        recv_result = conn.recv() => {
163                            match recv_result {
164                                Ok(Some(msg)) => {
165                                    // Parse the message
166                                    match serde_json::from_str::<TapEvent>(&msg) {
167                                        Ok(event) => {
168                                            let event_id = event.id();
169
170                                            // Send ack if enabled (before sending event to channel)
171                                            if config.send_acks
172                                                && let Err(err) = conn.send_ack(event_id).await
173                                            {
174                                                tracing::warn!(error = %err, "Failed to send ack");
175                                                // Don't break connection for ack errors
176                                            }
177
178                                            // Send event to channel
179                                            let event = Arc::new(event);
180                                            if event_tx.send(Ok(event)).await.is_err() {
181                                                // Receiver dropped, exit task
182                                                tracing::debug!("Event receiver dropped, closing connection");
183                                                let _ = conn.close().await;
184                                                return;
185                                            }
186                                        }
187                                        Err(err) => {
188                                            // Parse errors don't affect connection
189                                            tracing::warn!(error = %err, "Failed to parse TAP message");
190
191                                            // Try to extract just the ID using fallback parser
192                                            // so we can still ack the message even if full parsing fails
193                                            if config.send_acks {
194                                                if let Some(event_id) = extract_event_id(&msg) {
195                                                    tracing::debug!(event_id, "Extracted event ID via fallback parser");
196                                                    if let Err(ack_err) = conn.send_ack(event_id).await {
197                                                        tracing::warn!(error = %ack_err, "Failed to send ack for unparseable message");
198                                                    }
199                                                } else {
200                                                    tracing::warn!("Could not extract event ID from unparseable message");
201                                                }
202                                            }
203
204                                            if event_tx.send(Err(TapError::ParseError(err.to_string()))).await.is_err() {
205                                                tracing::debug!("Event receiver dropped, closing connection");
206                                                let _ = conn.close().await;
207                                                return;
208                                            }
209                                        }
210                                    }
211                                }
212                                Ok(None) => {
213                                    // Connection closed by server
214                                    tracing::debug!("TAP connection closed by server");
215                                    break;
216                                }
217                                Err(err) => {
218                                    // Connection error
219                                    tracing::warn!(error = %err, "TAP connection error");
220                                    break;
221                                }
222                            }
223                        }
224                    }
225                }
226            }
227            Err(err) => {
228                tracing::warn!(error = %err, attempt, "Failed to connect to TAP service");
229            }
230        }
231
232        // Increment attempt counter
233        attempt += 1;
234
235        // Check if we've exceeded max attempts
236        if let Some(max) = config.max_reconnect_attempts
237            && attempt >= max
238        {
239            tracing::error!(attempts = attempt, "Max reconnection attempts exceeded");
240            let _ = event_tx
241                .send(Err(TapError::MaxReconnectAttemptsExceeded(attempt)))
242                .await;
243            break;
244        }
245
246        // Wait before reconnecting with exponential backoff
247        tracing::debug!(
248            delay_ms = current_reconnect_delay.as_millis(),
249            attempt,
250            "Waiting before reconnection"
251        );
252
253        tokio::select! {
254            _ = close_rx.recv() => {
255                tracing::debug!("Connection task received close signal during backoff");
256                return;
257            }
258            _ = tokio::time::sleep(current_reconnect_delay) => {
259                // Update delay for next attempt
260                current_reconnect_delay = Duration::from_secs_f64(
261                    (current_reconnect_delay.as_secs_f64() * config.reconnect_backoff_multiplier)
262                        .min(config.max_reconnect_delay.as_secs_f64()),
263                );
264            }
265        }
266    }
267
268    tracing::debug!("Connection task exiting");
269}
270
271/// Create a new TAP stream with the given configuration.
272pub fn connect(config: TapConfig) -> TapStream {
273    TapStream::new(config)
274}
275
276/// Create a new TAP stream connected to the given hostname.
277///
278/// Uses default configuration values.
279pub fn connect_to(hostname: &str) -> TapStream {
280    TapStream::new(TapConfig::new(hostname))
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_stream_initial_state() {
        // Creating a TapStream spawns the connection task, so a runtime is required.
        let stream = TapStream::new(TapConfig::new("localhost:9999"));
        assert!(!stream.is_closed());
    }

    #[tokio::test]
    async fn test_stream_close() {
        let mut stream = TapStream::new(TapConfig::new("localhost:9999"));
        assert!(!stream.is_closed());
        stream.close().await;
        assert!(stream.is_closed());
        // A closed stream terminates immediately.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_connect_functions() {
        let mut stream = connect(TapConfig::new("localhost:9999"));
        assert!(!stream.is_closed());
        stream.close().await;
        assert!(stream.is_closed());

        let mut stream = connect_to("localhost:9999");
        assert!(!stream.is_closed());
        stream.close().await;
        assert!(stream.is_closed());
    }

    #[test]
    fn test_reconnect_delay_calculation() {
        // Test the delay calculation logic
        let initial = Duration::from_secs(1);
        let max = Duration::from_secs(10);
        let multiplier = 2.0;

        let mut delay = initial;
        assert_eq!(delay, Duration::from_secs(1));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(2));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(4));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(8));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(10)); // Capped at max
    }
}