use std::{collections::HashSet, marker::PhantomData, sync::Arc, time::Instant};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, de::DeserializeOwned};
use simulator_api::{
subscribe_config::{Compression, SubscribeConfig},
ws_compression::WsStreamDecompressor,
};
use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
use tokio::{
net::TcpStream,
sync::{mpsc, watch},
task::JoinHandle,
};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async,
tungstenite::{Message, client::IntoClientRequest},
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use super::{
CONNECT_TIMEOUT, ConnectionStatus, HANDSHAKE_RESPONSE_TIMEOUT, KEEPALIVE_INTERVAL,
KEEPALIVE_MISS_DEADLINE, RECONNECT_UNGATED_ATTEMPTS, RECONNECT_UPTIME_RESET, ReconnectBudget,
ReconnectCoordinator, cancellable_sleep,
};
use crate::{error::err_chain, subscriptions::AccountDiffNotification, urls::http_to_ws_url};
/// Caller-facing handle returned by the `spawn_*_subscription_manager` functions.
///
/// The background task owns the websocket; the caller observes it only through
/// these three endpoints.
pub struct SubscriptionHandle {
    /// Most recent connection state published by the background task.
    pub status: watch::Receiver<ConnectionStatus>,
    /// Decoded notifications, delivered in the order the task processed them.
    pub notifications: mpsc::Receiver<SubscriptionNotification>,
    /// Join handle of the background task; resolves once the task exits.
    pub join: JoinHandle<()>,
}
/// One notification forwarded from the server to the caller.
#[derive(Debug)]
pub enum SubscriptionNotification {
    /// Payload of a `transactionNotification` (boxed, presumably to keep the enum small).
    Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
    /// Payload of an `accountDiffNotification`.
    AccountDiff(AccountDiffNotification),
}
/// Protocol details that differ between subscription types, so a single generic
/// task can drive both transaction and account-diff subscriptions.
trait SubKind: Send + Sync + 'static {
    /// Type decoded from `params.result` of an incoming notification.
    type Notification: DeserializeOwned + Send + 'static;
    /// Short name attached to log records as the `kind` field.
    const LABEL: &'static str;
    /// JSON-RPC method used to open one subscription.
    const SUBSCRIBE_METHOD: &'static str;
    /// JSON-RPC `method` value carried by notifications for this kind.
    const NOTIFICATION_METHOD: &'static str;

    /// Builds the `params` value of a subscribe request for `program_id`.
    fn subscribe_params(program_id: &str) -> serde_json::Value;

    /// Wraps a decoded payload in the public [`SubscriptionNotification`] enum.
    fn into_notification(notification: Self::Notification) -> SubscriptionNotification;

    /// Slot of a payload; the highest delivered slot becomes the replay point
    /// on reconnect.
    fn slot_of(notification: &Self::Notification) -> u64;
}
/// Marker type for `accountDiffSubscribe` subscriptions, requested with
/// `address_type: "program"`.
struct AccountDiff;
impl SubKind for AccountDiff {
    type Notification = AccountDiffNotification;
    const LABEL: &'static str = "account-diff";
    const SUBSCRIBE_METHOD: &'static str = "accountDiffSubscribe";
    const NOTIFICATION_METHOD: &'static str = "accountDiffNotification";

    fn subscribe_params(program_id: &str) -> serde_json::Value {
        let options = serde_json::json!({ "address_type": "program" });
        serde_json::json!([program_id, options])
    }

    fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
        SubscriptionNotification::AccountDiff(notification)
    }

    fn slot_of(notification: &Self::Notification) -> u64 {
        notification.context.slot
    }
}
/// Marker type for `transactionSubscribe` subscriptions on transactions that
/// mention a program, at `confirmed` commitment.
struct Transaction;
impl SubKind for Transaction {
    type Notification = EncodedConfirmedTransactionWithStatusMeta;
    const LABEL: &'static str = "transaction";
    const SUBSCRIBE_METHOD: &'static str = "transactionSubscribe";
    const NOTIFICATION_METHOD: &'static str = "transactionNotification";

    fn subscribe_params(program_id: &str) -> serde_json::Value {
        let filter = serde_json::json!({ "mentions": [program_id] });
        let config = serde_json::json!({ "commitment": "confirmed" });
        serde_json::json!([filter, config])
    }

    fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
        let boxed = Box::new(notification);
        SubscriptionNotification::Transaction(boxed)
    }

    fn slot_of(notification: &Self::Notification) -> u64 {
        notification.slot
    }
}
/// Spawns a background task that streams `transactionSubscribe` notifications
/// for every program in `program_ids` and reconnects on failure.
///
/// Once enough reconnect attempts have failed, later attempts wait on
/// `coordinator` (when one is given). The task exits when `cancel` fires, when
/// every subscription reports `subscriptionComplete`, when the notification
/// receiver is dropped, or when no further retry is allowed.
pub fn spawn_transaction_subscription_manager(
    rpc_endpoint: String,
    program_ids: Vec<String>,
    cancel: CancellationToken,
    coordinator: Option<Arc<ReconnectCoordinator>>,
) -> SubscriptionHandle {
    spawn_subscription_manager::<Transaction>(rpc_endpoint, program_ids, cancel, coordinator)
}
/// Spawns a background task that streams `accountDiffSubscribe` notifications
/// for every program in `program_ids` and reconnects on failure.
///
/// Reconnect gating and exit conditions are the same as for
/// [`spawn_transaction_subscription_manager`].
pub fn spawn_account_diff_subscription_manager(
    rpc_endpoint: String,
    program_ids: Vec<String>,
    cancel: CancellationToken,
    coordinator: Option<Arc<ReconnectCoordinator>>,
) -> SubscriptionHandle {
    spawn_subscription_manager::<AccountDiff>(rpc_endpoint, program_ids, cancel, coordinator)
}
/// Generic constructor behind the public `spawn_*` helpers: creates the
/// notification and status channels and spawns [`Task::run`] for kind `K`.
fn spawn_subscription_manager<K>(
    rpc_endpoint: String,
    program_ids: Vec<String>,
    cancel: CancellationToken,
    coordinator: Option<Arc<ReconnectCoordinator>>,
) -> SubscriptionHandle
where
    K: SubKind,
{
    // Bounded: a consumer that falls behind makes the read loop wait on `send`.
    let (notifications_tx, notifications) = mpsc::channel(1024);
    // Nothing is connected until the task publishes otherwise.
    let (status_tx, status) = watch::channel(ConnectionStatus::Down);
    let join = tokio::spawn(
        Task::<K> {
            rpc_endpoint,
            program_ids,
            notifications_tx,
            status_tx,
            cancel,
            coordinator,
            _marker: PhantomData,
        }
        .run(),
    );
    SubscriptionHandle {
        status,
        notifications,
        join,
    }
}
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
type Subs = HashSet<u64>;
struct Task<K: SubKind> {
rpc_endpoint: String,
program_ids: Vec<String>,
notifications_tx: mpsc::Sender<SubscriptionNotification>,
status_tx: watch::Sender<ConnectionStatus>,
cancel: CancellationToken,
coordinator: Option<Arc<ReconnectCoordinator>>,
_marker: PhantomData<fn() -> K>,
}
impl<K: SubKind> Task<K> {
    /// Drives the connect → subscribe → stream → reconnect cycle until the
    /// task has a reason to stop.
    ///
    /// Each pass publishes `Down`. Once `budget.attempt()` reaches
    /// `RECONNECT_UNGATED_ATTEMPTS`, it then waits for a reconnect slot from
    /// the coordinator. Next it connects and subscribes (racing `cancel`),
    /// publishes `Up`, and runs [`message_loop`] while holding the
    /// coordinator's `enter()` guard.
    ///
    /// Failures go through [`retry_or_fail`]. A connection that stayed up for
    /// at least `RECONNECT_UPTIME_RESET` resets the budget first. A lost
    /// connection is reported as `Down` immediately, before any backoff sleep.
    ///
    /// The task ends on cancellation, when every subscription completes, when
    /// the notification receiver is dropped, or when `retry_or_fail` declines
    /// to retry (it publishes `Failed` when no backoff remains). On
    /// cancellation or completion the last published status is left as-is.
    async fn run(self) {
        let mut budget = ReconnectBudget::new();
        // Highest slot delivered so far; handed to `subscribe` on every
        // reconnect as the `replay_from_slot` setting.
        let mut replay_from_slot: Option<u64> = None;
        loop {
            if self.cancel.is_cancelled() {
                break;
            }
            publish(&self.status_tx, ConnectionStatus::Down);
            // Only gate through the coordinator after the first ungated
            // attempts; the slot is held until connect/subscribe finishes.
            let reconnect_slot = match &self.coordinator {
                Some(coord) if budget.attempt() >= RECONNECT_UNGATED_ATTEMPTS => {
                    let parked_at = Instant::now();
                    let Some(slot) = coord.reconnect_slot(&self.cancel).await else {
                        break;
                    };
                    // Tell the budget how long we were parked waiting for the slot.
                    budget.discount_parked(parked_at.elapsed());
                    Some(slot)
                }
                _ => None,
            };
            let connect_result = tokio::select! {
                biased;
                _ = self.cancel.cancelled() => None,
                result = async {
                    let ws = connect_ws(&self.rpc_endpoint).await?;
                    subscribe::<K>(ws, &self.program_ids, replay_from_slot).await
                } => Some(result),
            };
            let Some(connect_result) = connect_result else {
                break;
            };
            let Subscribed { ws, subs, pending } = match connect_result {
                Ok(v) => v,
                Err(why) => {
                    drop(reconnect_slot);
                    if retry_or_fail::<K>(
                        "connect",
                        why,
                        &mut budget,
                        &self.cancel,
                        &self.status_tx,
                    )
                    .await
                    {
                        continue;
                    }
                    break;
                }
            };
            // Enter the streaming state before giving the reconnect slot back.
            let streaming = self.coordinator.as_ref().map(|coord| coord.enter());
            drop(reconnect_slot);
            publish(&self.status_tx, ConnectionStatus::Up);
            let connected_at = Instant::now();
            let exit = message_loop::<K>(
                ws,
                subs,
                pending,
                &self.notifications_tx,
                &self.cancel,
                &mut replay_from_slot,
            )
            .await;
            drop(streaming);
            match exit {
                MessageLoopExit::Cancelled | MessageLoopExit::Completed => break,
                MessageLoopExit::ConnectionLost(why) => {
                    // The socket is already gone. Report it now; otherwise `Up`
                    // would stay visible for the whole backoff sleep inside
                    // `retry_or_fail`, because the loop-top publish runs after it.
                    publish(&self.status_tx, ConnectionStatus::Down);
                    if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
                        budget.reset();
                    }
                    if retry_or_fail::<K>(
                        "connection lost",
                        why,
                        &mut budget,
                        &self.cancel,
                        &self.status_tx,
                    )
                    .await
                    {
                        continue;
                    }
                    break;
                }
            }
        }
    }
}
/// Reason [`message_loop`] returned.
enum MessageLoopExit {
    /// `cancel` fired, or the notification receiver was dropped.
    Cancelled,
    /// The socket failed, closed, or went silent; the text says why.
    /// The caller treats this as retryable.
    ConnectionLost(String),
    /// Every subscription on the connection reported `subscriptionComplete`.
    Completed,
}
/// Streams frames from a subscribed socket until there is a reason to stop.
///
/// `pending` holds data frames captured while subscribe acks were awaited.
/// They are processed first, in arrival order. After that the socket is read
/// in a `select!` alongside cancellation and a keepalive timer. A ping is sent
/// every `KEEPALIVE_INTERVAL`, and the connection is reported lost if nothing
/// has been received for longer than `KEEPALIVE_MISS_DEADLINE`. Remote close
/// frames, read errors and end-of-stream are also reported as
/// `ConnectionLost`. `replay_from_slot` advances as notifications are
/// delivered.
///
/// Processing a frame can wait on the bounded notification channel when the
/// consumer is slow. That wait races `cancel`, and the keepalive clock
/// restarts once it completes. A slow consumer therefore neither blocks
/// shutdown nor gets a healthy socket declared silent.
async fn message_loop<K: SubKind>(
    mut ws: Ws,
    subs: Subs,
    pending: Vec<Message>,
    notifications_tx: &mpsc::Sender<SubscriptionNotification>,
    cancel: &CancellationToken,
    replay_from_slot: &mut Option<u64>,
) -> MessageLoopExit {
    let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
    ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    let mut last_inbound = Instant::now();
    let mut completed: HashSet<u64> = HashSet::new();
    // One decompressor per connection (presumably it keeps stream state
    // across frames, so it must not outlive the socket).
    let mut decompressor = match WsStreamDecompressor::new() {
        Ok(decompressor) => decompressor,
        Err(e) => return MessageLoopExit::ConnectionLost(format!("zstd decoder init: {e}")),
    };
    for msg in pending {
        // The send inside may block on a full channel; keep it cancellable.
        let outcome = tokio::select! {
            biased;
            _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
            outcome = process_data_frame::<K>(
                msg,
                &subs,
                notifications_tx,
                &mut completed,
                replay_from_slot,
                &mut decompressor,
            ) => outcome,
        };
        if let Some(exit) = frame_outcome_to_exit(outcome) {
            return exit;
        }
    }
    loop {
        tokio::select! {
            biased;
            _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
            _ = ping_timer.tick() => {
                if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
                    return MessageLoopExit::ConnectionLost(format!(
                        "no traffic for {:?}", last_inbound.elapsed()
                    ));
                }
                if let Err(e) = ws.send(Message::Ping(vec![])).await {
                    return MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
                }
            }
            msg = ws.next() => {
                last_inbound = Instant::now();
                match msg {
                    Some(Ok(frame @ (Message::Text(_) | Message::Binary(_)))) => {
                        // The send inside may block on a full channel; keep it cancellable.
                        let outcome = tokio::select! {
                            biased;
                            _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
                            outcome = process_data_frame::<K>(
                                frame,
                                &subs,
                                notifications_tx,
                                &mut completed,
                                replay_from_slot,
                                &mut decompressor,
                            ) => outcome,
                        };
                        if let Some(exit) = frame_outcome_to_exit(outcome) {
                            return exit;
                        }
                        // Time spent waiting on the consumer is not socket
                        // silence; restart the keepalive clock from here.
                        last_inbound = Instant::now();
                    }
                    Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
                    Some(Ok(Message::Close(frame))) => {
                        return MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
                    }
                    Some(Ok(Message::Frame(_))) => {}
                    Some(Err(e)) => return MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e))),
                    None => return MessageLoopExit::ConnectionLost("ws stream ended".into()),
                }
            }
        }
    }
}
/// Maps the result of processing one frame onto a loop exit.
/// `None` means the loop should keep reading.
fn frame_outcome_to_exit(outcome: Result<TextOutcome, String>) -> Option<MessageLoopExit> {
    let text_outcome = match outcome {
        Ok(text_outcome) => text_outcome,
        Err(why) => return Some(MessageLoopExit::ConnectionLost(why)),
    };
    match text_outcome {
        TextOutcome::Continue => None,
        TextOutcome::AllComplete => Some(MessageLoopExit::Completed),
        // Nobody is listening any more: stop as if cancelled.
        TextOutcome::ChannelClosed => Some(MessageLoopExit::Cancelled),
    }
}
async fn process_data_frame<K: SubKind>(
msg: Message,
subs: &Subs,
notifications_tx: &mpsc::Sender<SubscriptionNotification>,
completed: &mut HashSet<u64>,
replay_from_slot: &mut Option<u64>,
decompressor: &mut WsStreamDecompressor,
) -> Result<TextOutcome, String> {
match msg {
Message::Text(t) => {
Ok(handle_text::<K>(&t, subs, notifications_tx, completed, replay_from_slot).await)
}
Message::Binary(b) => {
let decoded = decompressor
.decompress(&b)
.map_err(|e| format!("ws decompress: {e}"))?;
match std::str::from_utf8(&decoded) {
Ok(t) => {
Ok(
handle_text::<K>(t, subs, notifications_tx, completed, replay_from_slot)
.await,
)
}
Err(_) => Ok(TextOutcome::Continue),
}
}
_ => Ok(TextOutcome::Continue),
}
}
/// Waits out the next backoff delay and reports whether the caller should retry.
///
/// When `budget` has no delay left, this publishes `Failed("{phase}: {reason}")`
/// and returns `false`. Otherwise it logs a warning and returns the result of
/// `cancellable_sleep`, which is presumably `false` when `cancel` interrupts
/// the sleep.
async fn retry_or_fail<K: SubKind>(
    phase: &'static str,
    reason: String,
    budget: &mut ReconnectBudget,
    cancel: &CancellationToken,
    status_tx: &watch::Sender<ConnectionStatus>,
) -> bool {
    let Some(delay) = budget.next_backoff() else {
        let failure = ConnectionStatus::Failed(format!("{phase}: {reason}"));
        publish(status_tx, failure);
        return false;
    };
    warn!(
        kind = K::LABEL,
        attempt = budget.attempt(),
        reason = %reason,
        ?delay,
        "subscription {phase}, retrying",
    );
    cancellable_sleep(delay, cancel).await
}
/// Stores `status` in the watch channel only when it differs from the current
/// value, so receivers are not woken for no-op transitions.
fn publish(tx: &watch::Sender<ConnectionStatus>, status: ConnectionStatus) {
    tx.send_if_modified(move |current| {
        let changed = *current != status;
        if changed {
            *current = status;
        }
        changed
    });
}
/// Opens a websocket to the URL derived from `rpc_endpoint`. The handshake is
/// bounded by `CONNECT_TIMEOUT`, and the HTTP upgrade response is discarded.
async fn connect_ws(rpc_endpoint: &str) -> Result<Ws, String> {
    let ws_url = http_to_ws_url(rpc_endpoint).map_err(|e| err_chain(&e))?;
    let request = ws_url
        .into_client_request()
        .map_err(|e| format!("build request: {}", err_chain(&e)))?;
    let attempt = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request)).await;
    let (stream, _response) = match attempt {
        Err(_elapsed) => return Err(format!("connect timeout after {CONNECT_TIMEOUT:?}")),
        Ok(Err(e)) => return Err(format!("connect: {}", err_chain(&e))),
        Ok(Ok(connected)) => connected,
    };
    Ok(stream)
}
/// Output of [`subscribe`]: a socket with every subscription acknowledged.
struct Subscribed {
    /// The live websocket.
    ws: Ws,
    /// Subscription ids returned in the acks.
    subs: Subs,
    /// Data frames that arrived while acks were awaited, in arrival order.
    pending: Vec<Message>,
}
/// Sends one subscribe request per program id and waits for each ack in turn.
///
/// Request ids are the 1-based positions in `program_ids`. Every request asks
/// for zstd compression and carries `replay_from_slot` (converted to `i64`)
/// when one is set. Data frames seen while waiting for acks are returned in
/// `pending` so the caller can process them before reading the socket again.
async fn subscribe<K: SubKind>(
    mut ws: Ws,
    program_ids: &[String],
    replay_from_slot: Option<u64>,
) -> Result<Subscribed, String> {
    let mut subs = Subs::new();
    let mut pending = Vec::new();
    // Identical for every request on this connection.
    let replay_from_slot = replay_from_slot.map(|slot| slot as i64);
    for (id, program_id) in (1u64..).zip(program_ids) {
        let mut params = K::subscribe_params(program_id);
        let config = SubscribeConfig {
            replay_from_slot,
            compression: Some(Compression::Zstd),
        };
        config.apply_to(&mut params);
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": K::SUBSCRIBE_METHOD,
            "params": params,
        });
        if let Err(e) = ws.send(Message::Text(request.to_string())).await {
            return Err(format!("subscribe send: {}", err_chain(&e)));
        }
        let sub_id = read_sub_ack(&mut ws, id, &mut pending).await?;
        subs.insert(sub_id);
    }
    debug!(
        kind = K::LABEL,
        count = subs.len(),
        "subscriptions established"
    );
    Ok(Subscribed { ws, subs, pending })
}
/// JSON-RPC response to a subscribe request.
///
/// Any text frame with a numeric `id` (and an optional numeric `result`)
/// deserializes into this; [`read_sub_ack`] then checks `error` before
/// `result`.
#[derive(Deserialize)]
struct SubAck {
/// Request id this response answers.
id: u64,
/// Subscription id on success; absent on error.
result: Option<u64>,
/// JSON-RPC error object, if the request was rejected.
#[serde(default)]
error: Option<serde_json::Value>,
}
/// Reads frames until the ack for `expected_id` arrives and returns the
/// subscription id it carries.
///
/// The whole wait is bounded by `HANDSHAKE_RESPONSE_TIMEOUT`. Acks for other
/// ids are discarded, other text/binary frames are stashed in `pending`, and
/// control frames are dropped. An ack with `error`, or without `result`,
/// fails the handshake.
async fn read_sub_ack(
    ws: &mut Ws,
    expected_id: u64,
    pending: &mut Vec<Message>,
) -> Result<u64, String> {
    let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
    loop {
        let frame = match tokio::time::timeout_at(deadline, ws.next()).await {
            Err(_) => {
                return Err(format!(
                    "subscribe ack timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"
                ));
            }
            Ok(None) => return Err("ws ended during subscribe".into()),
            Ok(Some(Err(e))) => return Err(format!("ws read: {}", err_chain(&e))),
            Ok(Some(Ok(frame))) => frame,
        };
        let ack = match &frame {
            Message::Text(t) => serde_json::from_str::<SubAck>(t).ok(),
            _ => None,
        };
        if let Some(ack) = ack {
            if ack.id != expected_id {
                continue;
            }
            return match (ack.error, ack.result) {
                (Some(err), _) => Err(format!("subscribe rejected: {err}")),
                (None, Some(sub_id)) => Ok(sub_id),
                (None, None) => Err("subscribe ack missing result".to_string()),
            };
        }
        if matches!(frame, Message::Text(_) | Message::Binary(_)) {
            pending.push(frame);
        }
    }
}
/// What happened after handling one decoded text payload.
enum TextOutcome {
    /// Nothing noteworthy; keep reading.
    Continue,
    /// Every subscription on this connection has reported completion.
    AllComplete,
    /// The notification receiver was dropped, so a send failed.
    ChannelClosed,
}
/// Routes one text payload.
///
/// A notification for one of `subs` is forwarded to `notifications_tx`. Only
/// after the send succeeds is `replay_from_slot` raised to the highest slot
/// delivered. A `subscriptionComplete` for one of `subs` is recorded in
/// `completed`. Everything else is ignored.
async fn handle_text<K: SubKind>(
    text: &str,
    subs: &Subs,
    notifications_tx: &mpsc::Sender<SubscriptionNotification>,
    completed: &mut HashSet<u64>,
    replay_from_slot: &mut Option<u64>,
) -> TextOutcome {
    match parse_notification::<K>(text, subs) {
        Some(notification) => {
            let slot = K::slot_of(&notification);
            let delivered = notifications_tx
                .send(K::into_notification(notification))
                .await;
            if delivered.is_err() {
                return TextOutcome::ChannelClosed;
            }
            let newest = match *replay_from_slot {
                Some(prev) if prev > slot => prev,
                _ => slot,
            };
            *replay_from_slot = Some(newest);
            TextOutcome::Continue
        }
        None => {
            let Some(sub_id) = parse_completion(text).filter(|id| subs.contains(id)) else {
                return TextOutcome::Continue;
            };
            completed.insert(sub_id);
            if subs.iter().all(|id| completed.contains(id)) {
                TextOutcome::AllComplete
            } else {
                TextOutcome::Continue
            }
        }
    }
}
/// Extracts the subscription id from a `subscriptionComplete` message.
/// Returns `None` for any other method or for text that does not parse.
fn parse_completion(text: &str) -> Option<u64> {
    #[derive(Deserialize)]
    struct Completion {
        method: String,
        params: CompletionParams,
    }
    #[derive(Deserialize)]
    struct CompletionParams {
        subscription: u64,
    }
    let parsed: Completion = serde_json::from_str(text).ok()?;
    if parsed.method != "subscriptionComplete" {
        return None;
    }
    Some(parsed.params.subscription)
}
/// Parses `text` as a `K::NOTIFICATION_METHOD` notification addressed to one
/// of `subs` and returns its `params.result` payload. Returns `None` otherwise,
/// including when the text does not deserialize.
fn parse_notification<K: SubKind>(text: &str, subs: &Subs) -> Option<K::Notification> {
    #[derive(Deserialize)]
    #[serde(bound = "T: DeserializeOwned")]
    struct Envelope<T> {
        method: String,
        params: EnvelopeParams<T>,
    }
    #[derive(Deserialize)]
    #[serde(bound = "T: DeserializeOwned")]
    struct EnvelopeParams<T> {
        subscription: u64,
        result: T,
    }
    let Envelope { method, params } =
        serde_json::from_str::<Envelope<K::Notification>>(text).ok()?;
    let addressed_to_us =
        method == K::NOTIFICATION_METHOD && subs.contains(&params.subscription);
    addressed_to_us.then_some(params.result)
}
#[cfg(test)]
mod tests {
    use super::parse_completion;

    #[test]
    fn parse_completion_extracts_subscription_id() {
        let complete =
            r#"{"jsonrpc":"2.0","method":"subscriptionComplete","params":{"subscription":7}}"#;
        assert_eq!(parse_completion(complete), Some(7));
    }

    #[test]
    fn parse_completion_ignores_other_messages() {
        let unrelated = [
            r#"{"jsonrpc":"2.0","method":"transactionNotification","params":{"subscription":7,"result":{}}}"#,
            "not json",
        ];
        for text in unrelated {
            assert_eq!(parse_completion(text), None, "unexpectedly matched: {text}");
        }
    }
}