use crate::{LaserstreamConfig, LaserstreamError};
use async_stream::try_stream;
use futures::TryStreamExt;
use futures_channel::mpsc as futures_mpsc;
use futures_util::{sink::SinkExt, Stream, StreamExt};
use std::{pin::Pin, time::Duration};
use tokio::time::sleep;
use tonic::Status;
use tracing::{error, instrument, warn};
use uuid;
use yellowstone_grpc_client::{ClientTlsConfig, GeyserGrpcClient};
use yellowstone_grpc_proto::geyser::{
subscribe_update::UpdateOneof, SubscribeRequest, SubscribeRequestFilterSlots,
SubscribeRequestPing, SubscribeUpdate,
};
/// Absolute upper bound on reconnect attempts: `20 * 60 / 5` = 240, i.e. roughly
/// twenty minutes' worth of retries at one attempt every five seconds.
const HARD_CAP_RECONNECT_ATTEMPTS: u32 = 20 * 60 / 5;

/// Fixed back-off, in milliseconds, slept before a reconnect attempt.
const FIXED_RECONNECT_INTERVAL_MS: u64 = 5_000;
/// Opens a Laserstream subscription that automatically reconnects on failure.
///
/// The returned stream yields the server's [`SubscribeUpdate`]s for `request`, except for
/// two kinds of update that are handled internally and never yielded:
/// - `Ping` updates, which are answered with a pong on the request sink;
/// - slot updates whose only matching filter is the internal slot filter this function
///   adds (when `request.slots` is empty) so that it can track a resume slot.
///
/// Whenever connecting fails or the server stream errors or ends, a new connection is
/// made. Once a slot has been observed, reconnect requests carry
/// `from_slot = Some(last_seen_slot)`. The first retry after a healthy period (or after
/// the initial attempt) is immediate; each further consecutive retry first sleeps
/// `FIXED_RECONNECT_INTERVAL_MS` milliseconds.
///
/// # Errors
///
/// Yields [`LaserstreamError::MaxReconnectAttempts`] and terminates once the number of
/// *consecutive* failed attempts reaches `config.max_reconnect_attempts`, clamped to
/// `HARD_CAP_RECONNECT_ATTEMPTS` (which is also the default when unset). The counter is
/// reset whenever an update is received, so occasional disconnects on a long-lived
/// stream do not accumulate toward the cap.
///
/// NOTE(review): this is a non-async fn returning a lazy stream, so `#[instrument]` only
/// wraps construction of the stream. Events logged while the stream is polled are not
/// emitted inside this span.
#[instrument(skip(config, request))]
pub fn subscribe(
    config: LaserstreamConfig,
    request: SubscribeRequest,
) -> impl Stream<Item = Result<SubscribeUpdate, LaserstreamError>> {
    try_stream! {
        // Consecutive failed attempts since the last successfully received update.
        let mut reconnect_attempts = 0;
        // Most recent slot seen in a slot update; 0 means "none seen yet".
        let mut tracked_slot: u64 = 0;
        let effective_max_attempts = config
            .max_reconnect_attempts
            .unwrap_or(HARD_CAP_RECONNECT_ATTEMPTS)
            .min(HARD_CAP_RECONNECT_ATTEMPTS);
        // Random suffix keeps the internal filter name from clashing with caller filters.
        let internal_slot_sub_id = format!(
            "internal-{}",
            uuid::Uuid::new_v4()
                .to_string()
                .split('-')
                .next()
                .expect("str::split always yields at least one item")
        );
        let api_key_string = config.api_key.clone();

        loop {
            let mut attempt_request = request.clone();
            // Without any slot filter we would never learn the current slot. In that case,
            // subscribe to slots ourselves; those updates are filtered out below.
            if attempt_request.slots.is_empty() {
                attempt_request.slots.insert(
                    internal_slot_sub_id.clone(),
                    SubscribeRequestFilterSlots::default(),
                );
            }
            // On reconnect, ask the server to resume from the last observed slot.
            if reconnect_attempts > 0 && tracked_slot > 0 {
                attempt_request.from_slot = Some(tracked_slot);
            }

            match connect_and_subscribe_once(&config, attempt_request, api_key_string.clone()).await {
                Ok((sender, stream)) => {
                    let mut sender: Pin<
                        Box<dyn futures_util::Sink<SubscribeRequest, Error = futures_mpsc::SendError> + Send>,
                    > = Box::pin(sender);
                    // Convert `yellowstone_grpc_proto::tonic::Status` into `tonic::Status`.
                    let mut stream: Pin<Box<dyn Stream<Item = Result<SubscribeUpdate, tonic::Status>> + Send>> =
                        Box::pin(stream.map_err(|ystatus| {
                            let code = tonic::Code::from_i32(ystatus.code() as i32);
                            tonic::Status::new(code, ystatus.message())
                        }));

                    while let Some(result) = stream.next().await {
                        match result {
                            Ok(update) => {
                                // Any received update shows the connection is working, so
                                // only consecutive failures count toward the cap.
                                reconnect_attempts = 0;

                                if matches!(&update.update_oneof, Some(UpdateOneof::Ping(_))) {
                                    let pong_req = SubscribeRequest {
                                        ping: Some(SubscribeRequestPing { id: 1 }),
                                        ..Default::default()
                                    };
                                    if let Err(e) = sender.send(pong_req).await {
                                        warn!(error = %e, "Failed to send pong");
                                        break;
                                    }
                                    continue;
                                }

                                if let Some(UpdateOneof::Slot(s)) = &update.update_oneof {
                                    tracked_slot = s.slot;
                                    // Slot updates that matched only our internal filter
                                    // were never requested by the caller.
                                    if update.filters.len() == 1
                                        && update.filters[0] == internal_slot_sub_id
                                    {
                                        continue;
                                    }
                                }

                                yield update;
                            }
                            Err(status) => {
                                warn!(error = %status, "Stream error, attempting reconnection");
                                eprintln!("Stream error, attempting reconnection: {}", status);
                                break;
                            }
                        }
                    }
                    warn!("Stream ended, preparing to reconnect...");
                    eprintln!("Stream ended, preparing to reconnect...");
                }
                Err(err) => {
                    warn!(error = %err, "Failed to connect/subscribe, preparing to reconnect...");
                    eprintln!("Failed to connect/subscribe, preparing to reconnect: {}", err);
                }
            }

            reconnect_attempts += 1;
            if reconnect_attempts >= effective_max_attempts {
                error!(attempts = effective_max_attempts, config_value = ?config.max_reconnect_attempts, "Max reconnection attempts reached");
                Err(LaserstreamError::MaxReconnectAttempts(Status::cancelled(
                    format!("Max reconnection attempts ({}) reached", effective_max_attempts)
                )))?;
            }

            // The first retry is immediate; later consecutive retries back off.
            if reconnect_attempts > 1 {
                let delay = Duration::from_millis(FIXED_RECONNECT_INTERVAL_MS);
                sleep(delay).await;
            }
        }
    }
}
#[instrument(skip(config, request, api_key))]
async fn connect_and_subscribe_once(
config: &LaserstreamConfig,
request: SubscribeRequest,
api_key: String,
) -> Result<
(
impl futures_util::Sink<SubscribeRequest, Error = futures_mpsc::SendError> + Send,
impl Stream<Item = Result<SubscribeUpdate, yellowstone_grpc_proto::tonic::Status>> + Send,
),
tonic::Status,
> {
let mut builder = GeyserGrpcClient::build_from_shared(config.endpoint.clone()) .map_err(|e| tonic::Status::internal(format!("Build client error: {}", e)))?
.x_token(Some(api_key))
.map_err(|e| tonic::Status::internal(format!("Set token error: {}", e)))?
.connect_timeout(Duration::from_secs(10))
.max_decoding_message_size(1_000_000_000)
.timeout(Duration::from_secs(10))
.tls_config(ClientTlsConfig::new().with_enabled_roots())
.map_err(|e| tonic::Status::internal(format!("TLS config error: {}", e)))?
.connect()
.await
.map_err(|e| tonic::Status::internal(format!("Connect error: {}", e)))?;
let (sender, stream) = builder
.subscribe_with_request(Some(request))
.await
.map_err(|e| tonic::Status::internal(format!("Subscribe error: {}", e)))?;
Ok((sender, stream))
}