1pub mod error;
2pub mod events;
3pub mod exports;
4
5use std::{
6 io::{Cursor, Read},
7 sync::Arc,
8 time::Duration,
9};
10
11use chrono::Utc;
12use futures_util::{stream::StreamExt, SinkExt};
13use tokio::{net::TcpStream, sync::Mutex, time::Instant};
14use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
15use tokio_util::sync::CancellationToken;
16use url::Url;
17use zstd::dict::DecoderDictionary;
18
19use crate::{
20 error::{ConfigValidationError, ConnectionError, JetstreamEventError},
21 events::JetstreamEvent,
22};
23
/// Well-known public Jetstream instances (the `*.bsky.network` hosts).
///
/// Convert a variant into its subscription URL with `String::from` / `.into()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DefaultJetstreamEndpoints {
    /// `wss://jetstream1.us-east.bsky.network/subscribe`
    USEastOne,
    /// `wss://jetstream2.us-east.bsky.network/subscribe`
    USEastTwo,
    /// `wss://jetstream1.us-west.bsky.network/subscribe`
    USWestOne,
    /// `wss://jetstream2.us-west.bsky.network/subscribe`
    USWestTwo,
}
38
39impl From<DefaultJetstreamEndpoints> for String {
40 fn from(endpoint: DefaultJetstreamEndpoints) -> Self {
41 match endpoint {
42 DefaultJetstreamEndpoints::USEastOne => {
43 "wss://jetstream1.us-east.bsky.network/subscribe".to_owned()
44 }
45 DefaultJetstreamEndpoints::USEastTwo => {
46 "wss://jetstream2.us-east.bsky.network/subscribe".to_owned()
47 }
48 DefaultJetstreamEndpoints::USWestOne => {
49 "wss://jetstream1.us-west.bsky.network/subscribe".to_owned()
50 }
51 DefaultJetstreamEndpoints::USWestTwo => {
52 "wss://jetstream2.us-west.bsky.network/subscribe".to_owned()
53 }
54 }
55 }
56}
57
/// Largest number of `wanted_collections` that `JetstreamConfig::validate` accepts.
const MAX_WANTED_COLLECTIONS: usize = 100;
/// Largest number of `wanted_dids` that `JetstreamConfig::validate` accepts.
const MAX_WANTED_DIDS: usize = 10_000;
62
/// Zstd dictionary used to decompress binary Jetstream messages, embedded at
/// compile time from `../zstd/dictionary` (path relative to this source file).
const JETSTREAM_ZSTD_DICTIONARY: &[u8] = include_bytes!("../zstd/dictionary");
67
/// Receiving end of the (unbounded) event channel returned by `JetstreamConnector::connect`.
pub type JetstreamReceiver = flume::Receiver<JetstreamEvent>;

/// Sending end of the event channel; the websocket task forwards decoded events through it.
type JetstreamSender = flume::Sender<JetstreamEvent>;
73
/// Entry point for consuming Jetstream events.
///
/// Holds a [`JetstreamConfig`]; call `connect` to start the background websocket
/// task and obtain a [`JetstreamReceiver`].
pub struct JetstreamConnector {
    /// Configuration checked by `JetstreamConfig::validate` in `new` (and again in `connect`).
    config: JetstreamConfig,
}
80
/// Compression requested from the Jetstream server via the `compress` query parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JetstreamCompression {
    /// Request uncompressed messages (`compress=false`).
    None,
    /// Request zstd-compressed messages (`compress=true`); binary frames are
    /// decompressed with the bundled Jetstream dictionary.
    Zstd,
}
88
89impl From<JetstreamCompression> for bool {
90 fn from(compression: JetstreamCompression) -> Self {
91 match compression {
92 JetstreamCompression::None => false,
93 JetstreamCompression::Zstd => true,
94 }
95 }
96}
97
/// Settings for a Jetstream subscription and its reconnect policy.
///
/// `JetstreamConfig::default()` targets the US East 1 endpoint with no filters,
/// no compression and no cursor. `JetstreamConfig::validate` enforces the size
/// limits on `wanted_collections` and `wanted_dids`.
pub struct JetstreamConfig {
    /// Base websocket URL; `construct_endpoint` appends the query parameters.
    /// See [`DefaultJetstreamEndpoints`] for the public instances.
    pub endpoint: String,
    /// Collection NSIDs to subscribe to, each sent as a `wantedCollections` query
    /// parameter. At most `MAX_WANTED_COLLECTIONS` (100) entries.
    pub wanted_collections: Vec<exports::Nsid>,
    /// DIDs to subscribe to, each sent as a `wantedDids` query parameter.
    /// At most `MAX_WANTED_DIDS` (10,000) entries.
    pub wanted_dids: Vec<exports::Did>,
    /// Compression to request: `Zstd` sends `compress=true`, `None` sends `compress=false`.
    pub compression: JetstreamCompression,
    /// Optional starting point, sent as the `cursor` query parameter in microseconds
    /// since the Unix epoch.
    ///
    /// NOTE(review): the endpoint URL is built once in `connect`, so automatic
    /// reconnects resend this same cursor instead of resuming from the last event
    /// received — confirm this is intended.
    pub cursor: Option<chrono::DateTime<Utc>>,

    /// The reconnect loop bumps a retry counter each time a connection attempt fails
    /// or a connection ends; once the counter reaches this value it gives up.
    pub max_retries: u32,

    /// Upper bound, in milliseconds, on the delay between reconnect attempts.
    pub max_delay_ms: u64,

    /// Base delay in milliseconds for exponential backoff (`base_delay_ms * 2^attempt`).
    pub base_delay_ms: u64,

    /// A connection that stays open longer than this many milliseconds resets the
    /// retry counter.
    pub reset_retries_min_ms: u64,
}
136
137impl Default for JetstreamConfig {
138 fn default() -> Self {
139 JetstreamConfig {
140 endpoint: DefaultJetstreamEndpoints::USEastOne.into(),
141 wanted_collections: Vec::new(),
142 wanted_dids: Vec::new(),
143 compression: JetstreamCompression::None,
144 cursor: None,
145 max_retries: 10,
146 max_delay_ms: 30_000,
147 base_delay_ms: 1_000,
148 reset_retries_min_ms: 30_000
149 }
150 }
151}
152
153impl JetstreamConfig {
154 pub fn construct_endpoint(&self, endpoint: &str) -> Result<Url, url::ParseError> {
156 let did_search_query = self
157 .wanted_dids
158 .iter()
159 .map(|s| ("wantedDids", s.to_string()));
160
161 let collection_search_query = self
162 .wanted_collections
163 .iter()
164 .map(|s| ("wantedCollections", s.to_string()));
165
166 let compression = (
167 "compress",
168 match self.compression {
169 JetstreamCompression::None => "false".to_owned(),
170 JetstreamCompression::Zstd => "true".to_owned(),
171 },
172 );
173
174 let cursor = self
175 .cursor
176 .map(|c| ("cursor", c.timestamp_micros().to_string()));
177
178 let params = did_search_query
179 .chain(collection_search_query)
180 .chain(std::iter::once(compression))
181 .chain(cursor)
182 .collect::<Vec<(&str, String)>>();
183
184 Url::parse_with_params(endpoint, params)
185 }
186
187 pub fn validate(&self) -> Result<(), ConfigValidationError> {
195 let collections = self.wanted_collections.len();
196 let dids = self.wanted_dids.len();
197
198 if collections > MAX_WANTED_COLLECTIONS {
199 return Err(ConfigValidationError::TooManyWantedCollections(collections));
200 }
201
202 if dids > MAX_WANTED_DIDS {
203 return Err(ConfigValidationError::TooManyDids(dids));
204 }
205
206 Ok(())
207 }
208}
209
210impl JetstreamConnector {
211 pub fn new(config: JetstreamConfig) -> Result<Self, ConfigValidationError> {
215 config.validate()?;
217 Ok(JetstreamConnector { config })
218 }
219
220 pub async fn connect(&self) -> Result<JetstreamReceiver, ConnectionError> {
225 self.config
227 .validate()
228 .map_err(ConnectionError::InvalidConfig)?;
229
230 let (send_channel, receive_channel) = flume::unbounded();
232
233 let configured_endpoint = self
234 .config
235 .construct_endpoint(&self.config.endpoint)
236 .map_err(ConnectionError::InvalidEndpoint)?;
237
238 let max_delay_ms = self.config.max_delay_ms;
239 let base_delay_ms = self.config.base_delay_ms;
240 let max_retries = self.config.max_retries;
241 let min_duration_before_retry_reset = Duration::from_millis(self.config.reset_retries_min_ms);
242
243 tokio::task::spawn(async move {
244
245 let mut retry_attempt = 0;
246
247 loop {
248 let dict = DecoderDictionary::copy(JETSTREAM_ZSTD_DICTIONARY);
249
250 if let Ok((ws_stream, _)) = connect_async(&configured_endpoint).await {
251 let now = Instant::now();
252 let _ = websocket_task(dict, ws_stream, send_channel.clone()).await;
253 let after_connection_closed = Instant::now();
254 if let Some(connection_alive_duration) = after_connection_closed.checked_duration_since(now) {
255 if connection_alive_duration > min_duration_before_retry_reset {
256 retry_attempt = 0
257 }
258 }
259 }
260
261 retry_attempt += 1;
262
263 if retry_attempt >= max_retries {
264 break;
265 }
266
267 let delay_ms = base_delay_ms * (2_u64.pow(retry_attempt));
269 log::error!("Connection failed, retrying in {delay_ms}ms...");
270 tokio::time::sleep(Duration::from_millis(delay_ms.min(max_delay_ms))).await;
271 log::info!("Attempting to reconnect...");
272 }
273 log::error!("Connection retries exhausted. Jetstream is disconnected.");
274 });
275
276 Ok(receive_channel)
277 }
278}
279
280async fn websocket_task(
283 dictionary: DecoderDictionary<'_>,
284 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
285 send_channel: JetstreamSender,
286) -> Result<(), JetstreamEventError> {
287 let (socket_write, mut socket_read) = ws.split();
289 let shared_socket_write = Arc::new(Mutex::new(socket_write));
290
291 let ping_cancellation_token = CancellationToken::new();
292 let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
293 let ping_cancelled = ping_cancellation_token.clone();
294 let ping_shared_socket_write = shared_socket_write.clone();
295 tokio::spawn(async move {
296 loop {
297 ping_interval.tick().await;
298 let false = ping_cancelled.is_cancelled() else {
299 break;
300 };
301 log::trace!("Sending ping");
302 match ping_shared_socket_write
303 .lock()
304 .await
305 .send(Message::Ping("ping".as_bytes().to_vec()))
306 .await
307 {
308 Ok(_) => (),
309 Err(error) => {
310 log::error!("Ping failed: {error}");
311 break;
312 }
313 }
314 }
315 });
316
317 let mut closing_connection = false;
318 loop {
319 match socket_read.next().await {
320 Some(Ok(message)) => {
321 match message {
322 Message::Text(json) => {
323 let event = serde_json::from_str::<JetstreamEvent>(&json)
324 .map_err(JetstreamEventError::ReceivedMalformedJSON)?;
325
326 if send_channel.send(event).is_err() {
327 log::info!(
330 "All receivers for the Jetstream connection have been dropped, closing connection."
331 );
332 closing_connection = true;
333 }
334 }
335 Message::Binary(zstd_json) => {
336 let mut cursor = Cursor::new(zstd_json);
337 let mut decoder = zstd::stream::Decoder::with_prepared_dictionary(
338 &mut cursor,
339 &dictionary,
340 )
341 .map_err(JetstreamEventError::CompressionDictionaryError)?;
342
343 let mut json = String::new();
344 decoder
345 .read_to_string(&mut json)
346 .map_err(JetstreamEventError::CompressionDecoderError)?;
347
348 let event = serde_json::from_str::<JetstreamEvent>(&json)
349 .map_err(JetstreamEventError::ReceivedMalformedJSON)?;
350
351 if send_channel.send(event).is_err() {
352 log::info!(
355 "All receivers for the Jetstream connection have been dropped, closing connection..."
356 );
357 closing_connection = true;
358 }
359 }
360 Message::Ping(vec) => {
361 log::trace!("Ping recieved, responding");
362 _ = shared_socket_write
363 .lock()
364 .await
365 .send(Message::Pong(vec))
366 .await;
367 }
368 Message::Close(close_frame) => {
369 if let Some(close_frame) = close_frame {
370 let reason = close_frame.reason;
371 let code = close_frame.code;
372 log::trace!("Connection closed. Reason: {reason}, Code: {code}");
373 }
374 }
375 Message::Pong(pong) => {
376 let pong_payload =
377 String::from_utf8(pong).unwrap_or("Invalid payload".to_string());
378 log::trace!("Pong recieved. Payload: {pong_payload}");
379 }
380 Message::Frame(_) => (),
381 }
382 }
383 Some(Err(error)) => {
384 log::error!("Web socket error: {error}");
385 ping_cancellation_token.cancel();
386 closing_connection = true;
387 }
388 None => {
389 log::error!("No web socket result");
390 ping_cancellation_token.cancel();
391 closing_connection = true;
392 }
393 }
394 if closing_connection {
395 _ = shared_socket_write.lock().await.close().await;
396 return Ok(());
397 }
398 }
399}