use crate::{LaserstreamConfig, LaserstreamError, config::CompressionEncoding as ConfigCompressionEncoding};
use async_stream::stream;
use futures::StreamExt;
use futures_channel::mpsc as futures_mpsc;
use futures_util::{sink::SinkExt, Stream};
use std::{pin::Pin, time::Duration};
use tokio::sync::mpsc;
use tokio::time::sleep;
use yellowstone_grpc_proto::tonic::{
Status, Request, metadata::MetadataValue, transport::Endpoint, codec::CompressionEncoding,
};
use tracing::{error, instrument, warn};
use uuid;
use yellowstone_grpc_client::{ClientTlsConfig, Interceptor};
use yellowstone_grpc_proto::prelude::{geyser_client::GeyserClient};
use yellowstone_grpc_proto::geyser::{
subscribe_update::UpdateOneof, SubscribeRequest, SubscribeRequestFilterSlots,
SubscribeRequestPing, SubscribeUpdate,
};
/// Upper bound on reconnection attempts: 20 minutes of retrying at one
/// attempt per 5-second interval (= 240 attempts).
const HARD_CAP_RECONNECT_ATTEMPTS: u32 = (20 * 60) / 5;

/// Fixed delay between reconnection attempts, in milliseconds (5 seconds).
const FIXED_RECONNECT_INTERVAL_MS: u64 = 5000;

/// SDK identity advertised via the `x-sdk-name` / `x-sdk-version` headers.
const SDK_NAME: &str = "laserstream-rust";
const SDK_VERSION: &str = "0.1.3";
/// Per-call gRPC interceptor that attaches the caller's API key (when one
/// was provided) plus SDK identification headers to every outgoing request.
#[derive(Clone)]
struct SdkMetadataInterceptor {
    // Pre-parsed `x-token` header value; `None` when no API key was supplied.
    x_token: Option<yellowstone_grpc_proto::tonic::metadata::AsciiMetadataValue>,
}

impl SdkMetadataInterceptor {
    /// Builds an interceptor from a raw API key string.
    ///
    /// An empty key produces an interceptor that sends no `x-token` header.
    /// Returns `Status::invalid_argument` when the key contains characters
    /// that are not valid in an ASCII metadata value.
    fn new(api_key: String) -> Result<Self, Status> {
        let x_token = match api_key.as_str() {
            "" => None,
            key => {
                let parsed = key.parse().map_err(|e| {
                    Status::invalid_argument(format!("Invalid API key: {}", e))
                })?;
                Some(parsed)
            }
        };
        Ok(Self { x_token })
    }
}
impl Interceptor for SdkMetadataInterceptor {
    /// Injects `x-token` (when configured) and the SDK name/version headers
    /// into the request metadata before it leaves the client.
    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
        let metadata = request.metadata_mut();
        if let Some(token) = self.x_token.as_ref() {
            metadata.insert("x-token", token.clone());
        }
        metadata.insert("x-sdk-name", MetadataValue::from_static(SDK_NAME));
        metadata.insert("x-sdk-version", MetadataValue::from_static(SDK_VERSION));
        Ok(request)
    }
}
/// Cloneable handle for pushing additional `SubscribeRequest`s (e.g. filter
/// updates) into an active subscription opened via `subscribe`.
#[derive(Clone)]
pub struct StreamHandle {
    // Unbounded queue drained by the subscription stream's select! loop.
    write_tx: mpsc::UnboundedSender<SubscribeRequest>,
}

impl StreamHandle {
    /// Queues `request` for delivery on the live stream.
    ///
    /// Fails with `LaserstreamError::ConnectionError` once the subscription
    /// stream has been dropped and the channel is closed.
    pub async fn write(&self, request: SubscribeRequest) -> Result<(), LaserstreamError> {
        match self.write_tx.send(request) {
            Ok(()) => Ok(()),
            Err(_) => Err(LaserstreamError::ConnectionError(
                "Write channel closed".to_string(),
            )),
        }
    }
}
/// Opens a Laserstream subscription with automatic reconnection.
///
/// Returns the update stream (lazy — nothing connects until it is polled)
/// and a [`StreamHandle`] for pushing follow-up `SubscribeRequest`s into the
/// live connection. Transient errors are yielded as `Err` items and then
/// retried every `FIXED_RECONNECT_INTERVAL_MS` until
/// `config.max_reconnect_attempts` (capped at `HARD_CAP_RECONNECT_ATTEMPTS`)
/// consecutive failures, after which `MaxReconnectAttempts` is yielded and
/// the stream ends.
#[instrument(skip(config, request))]
pub fn subscribe(
    config: LaserstreamConfig,
    request: SubscribeRequest,
) -> (
    impl Stream<Item = Result<SubscribeUpdate, LaserstreamError>>,
    StreamHandle,
) {
    // Caller -> stream channel for follow-up requests; drained in select! below.
    let (write_tx, mut write_rx) = mpsc::unbounded_channel::<SubscribeRequest>();
    let handle = StreamHandle { write_tx };
    let update_stream = stream! {
        // Consecutive failed sessions; reset to 0 on every successful connect.
        let mut reconnect_attempts = 0;
        // Highest slot seen on the internal slot subscription; used to compute
        // `from_slot` when reconnecting with replay enabled.
        let mut tracked_slot: u64 = 0;
        let effective_max_attempts = config
            .max_reconnect_attempts
            .unwrap_or(HARD_CAP_RECONNECT_ATTEMPTS) .min(HARD_CAP_RECONNECT_ATTEMPTS);
        let mut current_request = request.clone();
        // Filter id ("internal-<uuid prefix>") for a hidden slot subscription
        // added below; its updates are stripped before reaching the caller.
        let internal_slot_sub_id = format!("internal-{}", uuid::Uuid::new_v4().to_string().split('-').next().unwrap());
        let replay_enabled = config.replay;
        if replay_enabled {
            // Piggyback a slot subscription so we always know where to resume.
            current_request.slots.insert(
                internal_slot_sub_id.clone(),
                SubscribeRequestFilterSlots {
                    filter_by_commitment: Some(true), ..Default::default()
                }
            );
        }
        if !replay_enabled {
            // Never ask the server to rewind when replay is off.
            current_request.from_slot = None;
        }
        let api_key_string = config.api_key.clone();
        // One iteration per connection attempt / session.
        loop {
            let mut attempt_request = current_request.clone();
            if reconnect_attempts > 0 && tracked_slot > 0 && replay_enabled {
                // Reconnecting with replay: resume from the last tracked slot.
                let commitment_level = attempt_request.commitment.unwrap_or(0);
                // NOTE(review): level 0 is presumably "processed", where the
                // 31-slot rewind guards against unrooted forks, while 1|2
                // (confirmed/finalized) resume exactly — confirm against the
                // proto's CommitmentLevel enum.
                let from_slot = match commitment_level {
                    0 => tracked_slot.saturating_sub(31), 1 | 2 => tracked_slot, _ => tracked_slot.saturating_sub(31), };
                attempt_request.from_slot = Some(from_slot);
            } else if !replay_enabled {
                attempt_request.from_slot = None;
            }
            match connect_and_subscribe_once(&config, attempt_request, api_key_string.clone()).await {
                Ok((sender, stream)) => {
                    // Connected: reset the consecutive-failure counter.
                    reconnect_attempts = 0;
                    // Box + pin both halves so they can be polled from select! arms.
                    let mut sender: Pin<Box<dyn futures_util::Sink<SubscribeRequest, Error = futures_mpsc::SendError> + Send>> = Box::pin(sender);
                    let mut stream: Pin<Box<dyn Stream<Item = Result<SubscribeUpdate, Status>> + Send>> = Box::pin(stream);
                    // Client-side keepalive ping every 30s; an interval's first
                    // tick completes immediately, so consume it up front.
                    let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
                    ping_interval.tick().await; let mut ping_id = 0i32;
                    loop {
                        tokio::select! {
                            _ = ping_interval.tick() => {
                                // Periodic keepalive. Send failures are ignored here —
                                // a dead link will surface in the stream arm instead.
                                ping_id = ping_id.wrapping_add(1);
                                let ping_request = SubscribeRequest {
                                    ping: Some(SubscribeRequestPing { id: ping_id }),
                                    ..Default::default()
                                };
                                let _ = sender.send(ping_request).await;
                            },
                            result = stream.next() => {
                                if let Some(result) = result {
                                    match result {
                                        Ok(update) => {
                                            // Answer server pings so it keeps the session alive.
                                            if matches!(&update.update_oneof, Some(UpdateOneof::Ping(_))) {
                                                let pong_req = SubscribeRequest { ping: Some(SubscribeRequestPing { id: 1 }), ..Default::default() };
                                                if let Err(e) = sender.send(pong_req).await {
                                                    warn!(error = %e, "Failed to send pong");
                                                    break;
                                                }
                                                continue;
                                            }
                                            // Pongs (replies to our pings) carry no payload.
                                            if matches!(&update.update_oneof, Some(UpdateOneof::Pong(_))) {
                                                continue;
                                            }
                                            if let Some(UpdateOneof::Slot(s)) = &update.update_oneof {
                                                if replay_enabled {
                                                    // Remember the newest slot for replay on reconnect.
                                                    tracked_slot = s.slot;
                                                }
                                                // Drop slot updates that matched only our internal subscription.
                                                if update.filters.len() == 1 && update.filters.contains(&internal_slot_sub_id) {
                                                    continue;
                                                }
                                            }
                                            let mut clean_update = update;
                                            if replay_enabled {
                                                // Hide the internal filter id from the caller; if no
                                                // caller-requested filter remains, suppress the update.
                                                clean_update.filters.retain(|f| f != &internal_slot_sub_id);
                                                if !clean_update.filters.is_empty() {
                                                    yield Ok(clean_update);
                                                }
                                            } else {
                                                yield Ok(clean_update);
                                            }
                                        }
                                        Err(status) => {
                                            // Surface the error, then fall through to reconnect.
                                            warn!(error = %status, "Stream error, will reconnect after 5s delay");
                                            yield Err(LaserstreamError::Status(status.clone()));
                                            break;
                                        }
                                    }
                                } else {
                                    // Server closed the stream cleanly; reconnect.
                                    break;
                                }
                            }
                            // Forward caller writes (e.g. filter updates). Once the
                            // handle is dropped, recv() returns None and select!
                            // disables this arm for the rest of the session.
                            Some(write_request) = write_rx.recv() => {
                                if let Err(e) = sender.send(write_request).await {
                                    warn!(error = %e, "Failed to send write request");
                                    break;
                                }
                            }
                        }
                    }
                }
                Err(err) => {
                    // Connection-level failure; report it and retry below.
                    error!(error = %err, "Connection failed, will retry after 5s delay");
                    yield Err(LaserstreamError::Status(err));
                }
            }
            // Every session end — clean close, stream error, or connect
            // failure — counts as one attempt toward the cap.
            reconnect_attempts += 1;
            if reconnect_attempts >= effective_max_attempts {
                error!(attempts = effective_max_attempts, "Max reconnection attempts reached");
                yield Err(LaserstreamError::MaxReconnectAttempts(Status::cancelled(
                    format!("Max reconnection attempts ({}) reached", effective_max_attempts)
                )));
                return;
            }
            let delay = Duration::from_millis(FIXED_RECONNECT_INTERVAL_MS);
            sleep(delay).await;
        }
    };
    (update_stream, handle)
}
#[instrument(skip(config, request, api_key))]
async fn connect_and_subscribe_once(
config: &LaserstreamConfig,
request: SubscribeRequest,
api_key: String,
) -> Result<
(
impl futures_util::Sink<SubscribeRequest, Error = futures_mpsc::SendError> + Send,
impl Stream<Item = Result<SubscribeUpdate, yellowstone_grpc_proto::tonic::Status>> + Send,
),
Status,
> {
let options = &config.channel_options;
let interceptor = SdkMetadataInterceptor::new(api_key)?;
let mut endpoint = Endpoint::from_shared(config.endpoint.clone())
.map_err(|e| Status::internal(format!("Failed to parse endpoint: {}", e)))?
.connect_timeout(Duration::from_secs(options.connect_timeout_secs.unwrap_or(10)))
.timeout(Duration::from_secs(options.timeout_secs.unwrap_or(30)))
.http2_keep_alive_interval(Duration::from_secs(options.http2_keep_alive_interval_secs.unwrap_or(30)))
.keep_alive_timeout(Duration::from_secs(options.keep_alive_timeout_secs.unwrap_or(5)))
.keep_alive_while_idle(options.keep_alive_while_idle.unwrap_or(true))
.initial_stream_window_size(options.initial_stream_window_size.or(Some(1024 * 1024 * 4)))
.initial_connection_window_size(options.initial_connection_window_size.or(Some(1024 * 1024 * 8)))
.http2_adaptive_window(options.http2_adaptive_window.unwrap_or(true))
.tcp_nodelay(options.tcp_nodelay.unwrap_or(true))
.buffer_size(options.buffer_size.or(Some(1024 * 64)));
if let Some(tcp_keepalive_secs) = options.tcp_keepalive_secs {
endpoint = endpoint.tcp_keepalive(Some(Duration::from_secs(tcp_keepalive_secs)));
}
endpoint = endpoint
.tls_config(ClientTlsConfig::new().with_enabled_roots())
.map_err(|e| Status::internal(format!("TLS config error: {}", e)))?;
let channel = endpoint
.connect()
.await
.map_err(|e| Status::unavailable(format!("Connection failed: {}", e)))?;
let mut geyser_client = GeyserClient::with_interceptor(channel, interceptor);
geyser_client = geyser_client
.max_decoding_message_size(options.max_decoding_message_size.unwrap_or(1_000_000_000))
.max_encoding_message_size(options.max_encoding_message_size.unwrap_or(32_000_000));
if let Some(send_comp) = options.send_compression {
let encoding = match send_comp {
ConfigCompressionEncoding::Gzip => CompressionEncoding::Gzip,
ConfigCompressionEncoding::Zstd => CompressionEncoding::Zstd,
};
geyser_client = geyser_client.send_compressed(encoding);
}
if let Some(ref accept_comps) = options.accept_compression {
for comp in accept_comps {
let encoding = match comp {
ConfigCompressionEncoding::Gzip => CompressionEncoding::Gzip,
ConfigCompressionEncoding::Zstd => CompressionEncoding::Zstd,
};
geyser_client = geyser_client.accept_compressed(encoding);
}
} else {
geyser_client = geyser_client
.accept_compressed(CompressionEncoding::Gzip)
.accept_compressed(CompressionEncoding::Zstd);
}
let (mut subscribe_tx, subscribe_rx) = futures_mpsc::unbounded();
subscribe_tx
.send(request)
.await
.map_err(|e| Status::internal(format!("Failed to send initial request: {}", e)))?;
let response = geyser_client
.subscribe(subscribe_rx)
.await
.map_err(|e| Status::internal(format!("Subscription failed: {}", e)))?;
Ok((subscribe_tx, response.into_inner()))
}