pub mod error;
pub mod events;
pub mod exports;
use std::{
io::{Cursor, Read},
sync::Arc,
time::Duration,
};
use chrono::Utc;
use futures_util::{stream::StreamExt, SinkExt};
use tokio::{net::TcpStream, sync::Mutex, time::Instant};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
use tokio_util::sync::CancellationToken;
use url::Url;
use zstd::dict::DecoderDictionary;
use crate::{
error::{ConfigValidationError, ConnectionError, JetstreamEventError},
events::JetstreamEvent,
};
/// Well-known public Jetstream instances.
///
/// Convert a variant into its WebSocket subscription URL with `String::from`
/// (or `.into()`), for example when filling `JetstreamConfig::endpoint`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DefaultJetstreamEndpoints {
    /// `wss://jetstream1.us-east.bsky.network/subscribe`
    USEastOne,
    /// `wss://jetstream2.us-east.bsky.network/subscribe`
    USEastTwo,
    /// `wss://jetstream1.us-west.bsky.network/subscribe`
    USWestOne,
    /// `wss://jetstream2.us-west.bsky.network/subscribe`
    USWestTwo,
}
impl From<DefaultJetstreamEndpoints> for String {
    /// Expands a well-known instance into its `wss://…/subscribe` URL.
    fn from(endpoint: DefaultJetstreamEndpoints) -> Self {
        let url: &'static str = match endpoint {
            DefaultJetstreamEndpoints::USEastOne => "wss://jetstream1.us-east.bsky.network/subscribe",
            DefaultJetstreamEndpoints::USEastTwo => "wss://jetstream2.us-east.bsky.network/subscribe",
            DefaultJetstreamEndpoints::USWestOne => "wss://jetstream1.us-west.bsky.network/subscribe",
            DefaultJetstreamEndpoints::USWestTwo => "wss://jetstream2.us-west.bsky.network/subscribe",
        };
        String::from(url)
    }
}
/// Receiving half of the event channel handed out by `JetstreamConnector::connect`.
pub type JetstreamReceiver = flume::Receiver<JetstreamEvent>;
/// Sending half kept by the background connection task.
type JetstreamSender = flume::Sender<JetstreamEvent>;

/// `JetstreamConfig::validate` rejects configs listing more collections than this.
const MAX_WANTED_COLLECTIONS: usize = 100;
/// `JetstreamConfig::validate` rejects configs listing more DIDs than this.
const MAX_WANTED_DIDS: usize = 10_000;
/// zstd dictionary embedded at compile time; used to decode binary (compressed) frames.
const JETSTREAM_ZSTD_DICTIONARY: &[u8] = include_bytes!("../zstd/dictionary");
/// Entry point for subscribing to a Jetstream instance.
///
/// `JetstreamConnector::new` only builds one from a config that passes
/// `JetstreamConfig::validate`; call `JetstreamConnector::connect` to start
/// receiving events on a channel.
pub struct JetstreamConnector {
/// Connection settings read by `connect` (endpoint, filters, retry policy).
config: JetstreamConfig,
}
/// Compression requested from the server via the `compress` query parameter.
///
/// Selected through `JetstreamConfig::compression`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JetstreamCompression {
    /// Request uncompressed messages (`compress=false`).
    None,
    /// Request zstd-compressed messages (`compress=true`); binary frames are
    /// decoded with the bundled zstd dictionary.
    Zstd,
}
impl From<JetstreamCompression> for bool {
    /// `true` exactly when zstd compression is selected.
    fn from(compression: JetstreamCompression) -> Self {
        matches!(compression, JetstreamCompression::Zstd)
    }
}
/// Settings for a Jetstream subscription and its reconnection policy.
///
/// Filters and options are encoded as URL query parameters by
/// `JetstreamConfig::construct_endpoint`; `JetstreamConfig::validate` enforces
/// the filter-size limits.
pub struct JetstreamConfig {
/// Base WebSocket URL to connect to; query parameters are appended to it.
pub endpoint: String,
/// Collections to subscribe to, each sent as a `wantedCollections` query parameter.
/// `validate` rejects more than 100 entries.
pub wanted_collections: Vec<exports::Nsid>,
/// Repositories to subscribe to, each sent as a `wantedDids` query parameter.
/// `validate` rejects more than 10,000 entries.
pub wanted_dids: Vec<exports::Did>,
/// Sent as `compress=true` for `Zstd`, `compress=false` for `None`.
pub compression: JetstreamCompression,
/// Optional `cursor` query parameter, encoded as microseconds since the Unix epoch.
/// NOTE(review): presumably the point the server replays events from — confirm.
pub cursor: Option<chrono::DateTime<Utc>>,
/// Bounds how many times the background task in `connect` retries after a
/// failed or closed connection before giving up.
pub max_retries: u32,
/// Upper bound, in milliseconds, on the wait between reconnection attempts.
pub max_delay_ms: u64,
/// Base of the exponential backoff: the wait before retry `n` is
/// `base_delay_ms * 2^n` milliseconds, capped at `max_delay_ms`.
pub base_delay_ms: u64,
/// If a connection stays open longer than this many milliseconds, the retry
/// counter is reset when it closes.
pub reset_retries_min_ms: u64,
}
impl Default for JetstreamConfig {
    /// Subscribes to `USEastOne` with no filters, no compression and no cursor.
    ///
    /// Retries up to 10 times with a 1 s backoff base and a 30 s cap. The retry
    /// counter resets after a connection that lasted more than 30 s.
    fn default() -> Self {
        Self {
            endpoint: String::from(DefaultJetstreamEndpoints::USEastOne),
            compression: JetstreamCompression::None,
            cursor: None,
            wanted_collections: vec![],
            wanted_dids: vec![],
            base_delay_ms: 1_000,
            max_delay_ms: 30_000,
            max_retries: 10,
            reset_retries_min_ms: 30_000,
        }
    }
}
impl JetstreamConfig {
    /// Builds the subscription URL by appending this config's query parameters to `endpoint`.
    ///
    /// Parameters are emitted in this order:
    /// 1. every `wantedDids`;
    /// 2. every `wantedCollections`;
    /// 3. `compress` (`true`/`false`);
    /// 4. `cursor` (microseconds since the Unix epoch), only when a cursor is set.
    ///
    /// # Errors
    ///
    /// Returns the [`url::ParseError`] raised when `endpoint` is not a valid URL.
    pub fn construct_endpoint(&self, endpoint: &str) -> Result<Url, url::ParseError> {
        let mut query: Vec<(&str, String)> =
            Vec::with_capacity(self.wanted_dids.len() + self.wanted_collections.len() + 2);
        query.extend(self.wanted_dids.iter().map(|did| ("wantedDids", did.to_string())));
        query.extend(
            self.wanted_collections
                .iter()
                .map(|nsid| ("wantedCollections", nsid.to_string())),
        );
        let compress = match self.compression {
            JetstreamCompression::None => "false",
            JetstreamCompression::Zstd => "true",
        };
        query.push(("compress", compress.to_owned()));
        if let Some(cursor) = self.cursor {
            query.push(("cursor", cursor.timestamp_micros().to_string()));
        }
        Url::parse_with_params(endpoint, query)
    }

    /// Checks the filter lists against the size limits.
    ///
    /// # Errors
    ///
    /// `TooManyWantedCollections` if more than 100 collections are listed. Otherwise,
    /// `TooManyDids` if more than 10,000 DIDs are listed. Collections are checked first.
    pub fn validate(&self) -> Result<(), ConfigValidationError> {
        match (self.wanted_collections.len(), self.wanted_dids.len()) {
            (collections, _) if collections > MAX_WANTED_COLLECTIONS => {
                Err(ConfigValidationError::TooManyWantedCollections(collections))
            }
            (_, dids) if dids > MAX_WANTED_DIDS => Err(ConfigValidationError::TooManyDids(dids)),
            _ => Ok(()),
        }
    }
}
impl JetstreamConnector {
pub fn new(config: JetstreamConfig) -> Result<Self, ConfigValidationError> {
config.validate()?;
Ok(JetstreamConnector { config })
}
pub async fn connect(&self) -> Result<JetstreamReceiver, ConnectionError> {
self.config
.validate()
.map_err(ConnectionError::InvalidConfig)?;
let (send_channel, receive_channel) = flume::unbounded();
let configured_endpoint = self
.config
.construct_endpoint(&self.config.endpoint)
.map_err(ConnectionError::InvalidEndpoint)?;
let max_delay_ms = self.config.max_delay_ms;
let base_delay_ms = self.config.base_delay_ms;
let max_retries = self.config.max_retries;
let min_duration_before_retry_reset = Duration::from_millis(self.config.reset_retries_min_ms);
tokio::task::spawn(async move {
let mut retry_attempt = 0;
loop {
let dict = DecoderDictionary::copy(JETSTREAM_ZSTD_DICTIONARY);
if let Ok((ws_stream, _)) = connect_async(&configured_endpoint).await {
let now = Instant::now();
let _ = websocket_task(dict, ws_stream, send_channel.clone()).await;
let after_connection_closed = Instant::now();
if let Some(connection_alive_duration) = after_connection_closed.checked_duration_since(now) {
if connection_alive_duration > min_duration_before_retry_reset {
retry_attempt = 0
}
}
}
retry_attempt += 1;
if retry_attempt >= max_retries {
break;
}
let delay_ms = base_delay_ms * (2_u64.pow(retry_attempt));
log::error!("Connection failed, retrying in {delay_ms}ms...");
tokio::time::sleep(Duration::from_millis(delay_ms.min(max_delay_ms))).await;
log::info!("Attempting to reconnect...");
}
log::error!("Connection retries exhausted. Jetstream is disconnected.");
});
Ok(receive_channel)
}
}
async fn websocket_task(
dictionary: DecoderDictionary<'_>,
ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
send_channel: JetstreamSender,
) -> Result<(), JetstreamEventError> {
let (socket_write, mut socket_read) = ws.split();
let shared_socket_write = Arc::new(Mutex::new(socket_write));
let ping_cancellation_token = CancellationToken::new();
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
let ping_cancelled = ping_cancellation_token.clone();
let ping_shared_socket_write = shared_socket_write.clone();
tokio::spawn(async move {
loop {
ping_interval.tick().await;
let false = ping_cancelled.is_cancelled() else {
break;
};
log::trace!("Sending ping");
match ping_shared_socket_write
.lock()
.await
.send(Message::Ping("ping".as_bytes().to_vec()))
.await
{
Ok(_) => (),
Err(error) => {
log::error!("Ping failed: {error}");
break;
}
}
}
});
let mut closing_connection = false;
loop {
match socket_read.next().await {
Some(Ok(message)) => {
match message {
Message::Text(json) => {
let event = serde_json::from_str::<JetstreamEvent>(&json)
.map_err(JetstreamEventError::ReceivedMalformedJSON)?;
if send_channel.send(event).is_err() {
log::info!(
"All receivers for the Jetstream connection have been dropped, closing connection."
);
closing_connection = true;
}
}
Message::Binary(zstd_json) => {
let mut cursor = Cursor::new(zstd_json);
let mut decoder = zstd::stream::Decoder::with_prepared_dictionary(
&mut cursor,
&dictionary,
)
.map_err(JetstreamEventError::CompressionDictionaryError)?;
let mut json = String::new();
decoder
.read_to_string(&mut json)
.map_err(JetstreamEventError::CompressionDecoderError)?;
let event = serde_json::from_str::<JetstreamEvent>(&json)
.map_err(JetstreamEventError::ReceivedMalformedJSON)?;
if send_channel.send(event).is_err() {
log::info!(
"All receivers for the Jetstream connection have been dropped, closing connection..."
);
closing_connection = true;
}
}
Message::Ping(vec) => {
log::trace!("Ping recieved, responding");
_ = shared_socket_write
.lock()
.await
.send(Message::Pong(vec))
.await;
}
Message::Close(close_frame) => {
if let Some(close_frame) = close_frame {
let reason = close_frame.reason;
let code = close_frame.code;
log::trace!("Connection closed. Reason: {reason}, Code: {code}");
}
}
Message::Pong(pong) => {
let pong_payload =
String::from_utf8(pong).unwrap_or("Invalid payload".to_string());
log::trace!("Pong recieved. Payload: {pong_payload}");
}
Message::Frame(_) => (),
}
}
Some(Err(error)) => {
log::error!("Web socket error: {error}");
ping_cancellation_token.cancel();
closing_connection = true;
}
None => {
log::error!("No web socket result");
ping_cancellation_token.cancel();
closing_connection = true;
}
}
if closing_connection {
_ = shared_socket_write.lock().await.close().await;
return Ok(());
}
}
}