atproto_tap/stream.rs
1//! TAP event stream implementation.
2//!
3//! This module provides [`TapStream`], an async stream that yields TAP events
4//! with automatic connection management and reconnection handling.
5//!
6//! # Design
7//!
8//! The stream encapsulates all connection logic, allowing consumers to simply
9//! iterate over events using standard stream combinators or `tokio::select!`.
10//!
//! Reconnection is handled automatically with exponential backoff. Parse errors
//! are yielded as `Err` items but do not affect the connection; only connection
//! failures (transport errors or the server closing the socket) trigger
//! reconnection attempts.
14
15use crate::config::TapConfig;
16use crate::connection::TapConnection;
17use crate::errors::TapError;
18use crate::events::{TapEvent, extract_event_id};
19use futures::Stream;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::Duration;
24use tokio::sync::mpsc;
25
26/// An async stream of TAP events with automatic reconnection.
27///
28/// `TapStream` implements [`Stream`] and yields `Result<Arc<TapEvent>, TapError>`.
29/// Events are wrapped in `Arc` for efficient zero-cost sharing across consumers.
30///
31/// # Connection Management
32///
33/// The stream automatically:
34/// - Connects on first poll
35/// - Reconnects with exponential backoff on connection errors
36/// - Sends acknowledgments after parsing each message (if enabled)
37/// - Yields parse errors without affecting connection state
38///
39/// # Example
40///
41/// ```ignore
42/// use atproto_tap::{TapConfig, TapStream};
43/// use tokio_stream::StreamExt;
44///
45/// let config = TapConfig::builder()
46/// .hostname("localhost:2480")
47/// .build();
48///
49/// let mut stream = TapStream::new(config);
50///
51/// while let Some(result) = stream.next().await {
52/// match result {
53/// Ok(event) => println!("Event: {:?}", event),
54/// Err(e) => eprintln!("Error: {}", e),
55/// }
56/// }
57/// ```
58pub struct TapStream {
59 /// Receiver for events from the background task.
60 receiver: mpsc::Receiver<Result<Arc<TapEvent>, TapError>>,
61 /// Handle to request stream closure.
62 close_sender: Option<mpsc::Sender<()>>,
63 /// Whether the stream has been closed.
64 closed: bool,
65}
66
67impl TapStream {
68 /// Create a new TAP stream with the given configuration.
69 ///
70 /// The stream will start connecting immediately in a background task.
71 pub fn new(config: TapConfig) -> Self {
72 // Channel for events - buffer a few to handle bursts
73 let (event_tx, event_rx) = mpsc::channel(config.channel_buffer_size);
74 // Channel for close signal
75 let (close_tx, close_rx) = mpsc::channel(1);
76
77 // Spawn background task to manage connection
78 tokio::spawn(connection_task(config, event_tx, close_rx));
79
80 Self {
81 receiver: event_rx,
82 close_sender: Some(close_tx),
83 closed: false,
84 }
85 }
86
87 /// Close the stream and release resources.
88 ///
89 /// After calling this, the stream will yield `None` on the next poll.
90 pub async fn close(&mut self) {
91 if let Some(sender) = self.close_sender.take() {
92 // Signal the background task to close
93 let _ = sender.send(()).await;
94 }
95 self.closed = true;
96 }
97
98 /// Returns true if the stream is closed.
99 pub fn is_closed(&self) -> bool {
100 self.closed
101 }
102}
103
104impl Stream for TapStream {
105 type Item = Result<Arc<TapEvent>, TapError>;
106
107 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108 if self.closed {
109 return Poll::Ready(None);
110 }
111
112 self.receiver.poll_recv(cx)
113 }
114}
115
116impl Drop for TapStream {
117 fn drop(&mut self) {
118 // Drop the close_sender to signal the background task
119 self.close_sender.take();
120 tracing::debug!("TapStream dropped");
121 }
122}
123
124/// Background task that manages the WebSocket connection.
125async fn connection_task(
126 config: TapConfig,
127 event_tx: mpsc::Sender<Result<Arc<TapEvent>, TapError>>,
128 mut close_rx: mpsc::Receiver<()>,
129) {
130 let mut current_reconnect_delay = config.initial_reconnect_delay;
131 let mut attempt: u32 = 0;
132
133 loop {
134 // Check for close signal
135 if close_rx.try_recv().is_ok() {
136 tracing::debug!("Connection task received close signal");
137 break;
138 }
139
140 // Try to connect
141 tracing::debug!(attempt, hostname = %config.hostname, "Connecting to TAP service");
142 let conn_result = TapConnection::connect(&config).await;
143
144 match conn_result {
145 Ok(mut conn) => {
146 tracing::info!(hostname = %config.hostname, "TAP stream connected");
147 // Reset reconnection state on successful connect
148 current_reconnect_delay = config.initial_reconnect_delay;
149 attempt = 0;
150
151 // Event loop for this connection
152 loop {
153 tokio::select! {
154 biased;
155
156 _ = close_rx.recv() => {
157 tracing::debug!("Connection task received close signal during receive");
158 let _ = conn.close().await;
159 return;
160 }
161
162 recv_result = conn.recv() => {
163 match recv_result {
164 Ok(Some(msg)) => {
165 // Parse the message
166 match serde_json::from_str::<TapEvent>(&msg) {
167 Ok(event) => {
168 let event_id = event.id();
169
170 // Send ack if enabled (before sending event to channel)
171 if config.send_acks
172 && let Err(err) = conn.send_ack(event_id).await
173 {
174 tracing::warn!(error = %err, "Failed to send ack");
175 // Don't break connection for ack errors
176 }
177
178 // Send event to channel
179 let event = Arc::new(event);
180 if event_tx.send(Ok(event)).await.is_err() {
181 // Receiver dropped, exit task
182 tracing::debug!("Event receiver dropped, closing connection");
183 let _ = conn.close().await;
184 return;
185 }
186 }
187 Err(err) => {
188 // Parse errors don't affect connection
189 tracing::warn!(error = %err, "Failed to parse TAP message");
190
191 // Try to extract just the ID using fallback parser
192 // so we can still ack the message even if full parsing fails
193 if config.send_acks {
194 if let Some(event_id) = extract_event_id(&msg) {
195 tracing::debug!(event_id, "Extracted event ID via fallback parser");
196 if let Err(ack_err) = conn.send_ack(event_id).await {
197 tracing::warn!(error = %ack_err, "Failed to send ack for unparseable message");
198 }
199 } else {
200 tracing::warn!("Could not extract event ID from unparseable message");
201 }
202 }
203
204 if event_tx.send(Err(TapError::ParseError(err.to_string()))).await.is_err() {
205 tracing::debug!("Event receiver dropped, closing connection");
206 let _ = conn.close().await;
207 return;
208 }
209 }
210 }
211 }
212 Ok(None) => {
213 // Connection closed by server
214 tracing::debug!("TAP connection closed by server");
215 break;
216 }
217 Err(err) => {
218 // Connection error
219 tracing::warn!(error = %err, "TAP connection error");
220 break;
221 }
222 }
223 }
224 }
225 }
226 }
227 Err(err) => {
228 tracing::warn!(error = %err, attempt, "Failed to connect to TAP service");
229 }
230 }
231
232 // Increment attempt counter
233 attempt += 1;
234
235 // Check if we've exceeded max attempts
236 if let Some(max) = config.max_reconnect_attempts
237 && attempt >= max
238 {
239 tracing::error!(attempts = attempt, "Max reconnection attempts exceeded");
240 let _ = event_tx
241 .send(Err(TapError::MaxReconnectAttemptsExceeded(attempt)))
242 .await;
243 break;
244 }
245
246 // Wait before reconnecting with exponential backoff
247 tracing::debug!(
248 delay_ms = current_reconnect_delay.as_millis(),
249 attempt,
250 "Waiting before reconnection"
251 );
252
253 tokio::select! {
254 _ = close_rx.recv() => {
255 tracing::debug!("Connection task received close signal during backoff");
256 return;
257 }
258 _ = tokio::time::sleep(current_reconnect_delay) => {
259 // Update delay for next attempt
260 current_reconnect_delay = Duration::from_secs_f64(
261 (current_reconnect_delay.as_secs_f64() * config.reconnect_backoff_multiplier)
262 .min(config.max_reconnect_delay.as_secs_f64()),
263 );
264 }
265 }
266 }
267
268 tracing::debug!("Connection task exiting");
269}
270
271/// Create a new TAP stream with the given configuration.
272pub fn connect(config: TapConfig) -> TapStream {
273 TapStream::new(config)
274}
275
276/// Create a new TAP stream connected to the given hostname.
277///
278/// Uses default configuration values.
279pub fn connect_to(hostname: &str) -> TapStream {
280 TapStream::new(TapConfig::new(hostname))
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_stream_initial_state() {
        // Creating a TapStream requires a Tokio runtime for the spawn.
        let stream = TapStream::new(TapConfig::new("localhost:9999"));
        assert!(!stream.is_closed());
    }

    #[tokio::test]
    async fn test_stream_close() {
        let mut stream = TapStream::new(TapConfig::new("localhost:9999"));
        assert!(!stream.is_closed());
        stream.close().await;
        assert!(stream.is_closed());
        // A closed stream is exhausted, regardless of the background task.
        assert!(stream.next().await.is_none());
        // Closing twice is harmless.
        stream.close().await;
        assert!(stream.is_closed());
    }

    #[tokio::test]
    async fn test_connect_functions() {
        // Both helpers return a live stream; the connection happens in the
        // background and is torn down when the stream is dropped.
        let stream = connect(TapConfig::new("localhost:9999"));
        assert!(!stream.is_closed());

        let stream = connect_to("localhost:9999");
        assert!(!stream.is_closed());
    }

    #[test]
    fn test_reconnect_delay_calculation() {
        // Test the delay calculation logic.
        let initial = Duration::from_secs(1);
        let max = Duration::from_secs(10);
        let multiplier = 2.0;

        let mut delay = initial;
        assert_eq!(delay, Duration::from_secs(1));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(2));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(4));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(8));

        delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
        assert_eq!(delay, Duration::from_secs(10)); // Capped at max
    }
}