use std::{future::Future, time::Duration};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use solana_client::{
nonblocking::pubsub_client::PubsubClient,
rpc_response::{Response, RpcLogsResponse},
};
use solana_commitment_config::CommitmentConfig;
use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
use thiserror::Error;
use tokio::{
sync::{oneshot, watch},
task::JoinHandle,
};
use tokio_tungstenite::tungstenite::Message;
use crate::urls::{UrlError, http_to_ws_url};
/// Errors that can occur while setting up a subscription, i.e. before its background task
/// reports that it is ready.
#[derive(Debug, Error)]
pub enum SubscriptionError {
    /// The RPC endpoint could not be converted into a websocket URL.
    #[error(transparent)]
    InvalidUrl(#[from] UrlError),
    /// Opening the websocket / pubsub connection to `url` failed.
    #[error("pubsub connect to {url} failed: {source}")]
    Connect {
        url: String,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// Sending or confirming the subscription request failed.
    ///
    /// NOTE(review): the display text says `logs_subscribe` even when the failing request
    /// was an `accountDiffSubscribe`.
    #[error("logs_subscribe failed: {source}")]
    Subscribe {
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// The background subscription task went away before reporting whether setup
    /// succeeded.
    ///
    /// NOTE(review): check call sites — some setup helpers may reuse this variant for
    /// other confirmation failures.
    #[error("subscription task exited unexpectedly before signaling ready")]
    TaskDropped,
    /// No RPC endpoint was available to subscribe against.
    ///
    /// NOTE(review): not constructed in this module; presumably returned by session-level
    /// callers — confirm.
    #[error("session has no rpc_endpoint (was the session created?)")]
    NoRpcEndpoint,
}
/// Errors reported by a subscription's background task after it signaled ready; the task's
/// `JoinHandle` resolves to `Result<(), SubscriptionRuntimeError>`.
#[derive(Debug, Error)]
pub enum SubscriptionRuntimeError {
    /// The notification stream for `target` ended while the subscription was still active
    /// (no stop had been requested).
    #[error("{kind} subscription for {target} closed unexpectedly")]
    Closed { kind: &'static str, target: String },
    /// A notification callback task for `target` did not complete normally (it panicked or
    /// was cancelled), as reported by its `JoinError`.
    #[error("{kind} subscription callback worker for {target} failed: {source}")]
    CallbackWorker {
        kind: &'static str,
        target: String,
        #[source]
        source: tokio::task::JoinError,
    },
}
/// While draining after a stop request, give up once no message has arrived for this long.
const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
/// Upper bound on the total time spent draining after a stop request.
const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);
/// Background-task handle shared by every subscription kind in this module.
type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
/// Raw websocket used for the `accountDiffSubscribe`-based subscriptions.
type AccountDiffWs =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
/// Kind-agnostic handle to a running subscription; both [`LogSubscriptionHandle`] and
/// [`AccountDiffSubscriptionHandle`] convert into it via `From`.
pub struct SubscriptionHandle {
    /// The background task; resolves to `Err` if the subscription ends abnormally after it
    /// signaled ready.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to ask the background task to drain pending notifications and stop.
    pub stop: watch::Sender<bool>,
}
impl From<LogSubscriptionHandle> for SubscriptionHandle {
fn from(h: LogSubscriptionHandle) -> Self {
Self {
join_handle: h.join_handle,
stop: h.stop,
}
}
}
impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
fn from(h: AccountDiffSubscriptionHandle) -> Self {
Self {
join_handle: h.join_handle,
stop: h.stop,
}
}
}
/// Handle returned by [`subscribe_program_logs`].
pub struct LogSubscriptionHandle {
    /// The background task driving the log subscription.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to ask the background task to drain pending notifications and stop.
    pub stop: watch::Sender<bool>,
}
/// Subscribes to transaction logs that mention `program_id` over the RPC pubsub websocket.
///
/// The websocket URL is derived from `rpc_endpoint`. Every notification is handed to
/// `on_notification` on its own spawned task, so callbacks may run concurrently.
///
/// The returned future resolves once the subscription is confirmed or has failed. Setting
/// [`LogSubscriptionHandle::stop`] to `true`, or dropping that sender, makes the background
/// task do three things:
/// - drain notifications still arriving, until the stream is idle for
///   [`SUBSCRIPTION_DRAIN_IDLE_TIMEOUT`] or [`SUBSCRIPTION_DRAIN_MAX_DURATION`] elapses;
/// - wait for outstanding callbacks;
/// - exit.
///
/// # Errors
///
/// - [`SubscriptionError::InvalidUrl`] if `rpc_endpoint` cannot be converted.
/// - [`SubscriptionError::Connect`] or [`SubscriptionError::Subscribe`] if the pubsub setup
///   fails.
/// - [`SubscriptionError::TaskDropped`] if the background task ends before reporting
///   readiness.
///
/// Afterwards the task itself reports failures as [`SubscriptionRuntimeError`]: `Closed` if
/// the stream ends before a stop, and `CallbackWorker` if a callback task panicked or was
/// cancelled.
pub async fn subscribe_program_logs<F, Fut>(
    rpc_endpoint: &str,
    program_id: &str,
    commitment: CommitmentConfig,
    on_notification: F,
) -> Result<LogSubscriptionHandle, SubscriptionError>
where
    F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let ws_url = http_to_ws_url(rpc_endpoint)?;
    let program_id = program_id.to_string();
    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
    let (stop_tx, mut stop_rx) = watch::channel(false);
    let join_handle = tokio::spawn(async move {
        let client = match PubsubClient::new(&ws_url).await {
            Ok(c) => c,
            Err(e) => {
                let _ = ready_tx.send(Err(SubscriptionError::Connect {
                    url: ws_url,
                    source: Box::new(e),
                }));
                return Ok(());
            }
        };
        let (mut stream, _unsubscribe) = match client
            .logs_subscribe(
                RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
                RpcTransactionLogsConfig {
                    commitment: Some(commitment),
                },
            )
            .await
        {
            Ok(s) => s,
            Err(e) => {
                let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
                    source: Box::new(e),
                }));
                return Ok(());
            }
        };
        let _ = ready_tx.send(Ok(()));
        let kind = "program logs";
        // Handles of callback tasks that may still be running.
        let mut tasks: Vec<JoinHandle<()>> = Vec::new();
        // First callback failure observed while reaping; reported once the subscription stops.
        let mut callback_failure: Option<tokio::task::JoinError> = None;
        // Once the stop sender is gone, `changed()` fails immediately on every call; remember
        // that and treat it as a stop request instead of spinning in `select!`.
        let mut stop_sender_dropped = false;
        loop {
            if stop_sender_dropped || *stop_rx.borrow() {
                let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
                while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
                    drain_deadline,
                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
                )
                .await
                {
                    tasks.push(tokio::spawn(on_notification(notification)));
                }
                break;
            }
            let notification = tokio::select! {
                n = stream.next() => n,
                changed = stop_rx.changed() => {
                    stop_sender_dropped = changed.is_err();
                    continue;
                }
            };
            match notification {
                Some(n) => {
                    // Release handles of callbacks that already finished so `tasks` tracks
                    // only in-flight work rather than every notification ever received.
                    reap_finished_callbacks(&mut tasks, &mut callback_failure).await;
                    tasks.push(tokio::spawn(on_notification(n)));
                }
                None => return Err(subscription_runtime_closed(kind, &program_id)),
            }
        }
        if let Some(source) = callback_failure {
            return Err(callback_worker_failed(kind, &program_id, source));
        }
        for task in tasks {
            if let Err(source) = task.await {
                return Err(callback_worker_failed(kind, &program_id, source));
            }
        }
        Ok(())
    });
    match ready_rx.await {
        Ok(Ok(())) => Ok(LogSubscriptionHandle {
            join_handle,
            stop: stop_tx,
        }),
        Ok(Err(e)) => {
            join_handle.abort();
            Err(e)
        }
        Err(_) => {
            join_handle.abort();
            Err(SubscriptionError::TaskDropped)
        }
    }
}

/// Removes callback tasks that have already finished from `tasks`, recording the first
/// failure (panic or cancellation) into `failure` if none was recorded yet.
async fn reap_finished_callbacks(
    tasks: &mut Vec<JoinHandle<()>>,
    failure: &mut Option<tokio::task::JoinError>,
) {
    let mut index = 0;
    while index < tasks.len() {
        if tasks[index].is_finished() {
            // The task has completed, so awaiting its handle resolves immediately.
            if let Err(source) = tasks.swap_remove(index).await {
                if failure.is_none() {
                    *failure = Some(source);
                }
            }
        } else {
            index += 1;
        }
    }
}
/// `context` object of an account-diff notification.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffContext {
    /// Slot reported by the server for this notification.
    pub slot: u64,
}
/// `result` payload of an `accountDiffNotification` message.
///
/// Every field except `context` is optional and may be absent from the JSON.
/// NOTE(review): field meanings below are inferred from their names; the server's schema is
/// not visible here — confirm against the RPC provider's documentation.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffNotification {
    /// Slot context of the notification.
    pub context: AccountDiffContext,
    /// Presumably the changed account's address, when the server includes it.
    pub account: Option<String>,
    /// Presumably the signature of the transaction that produced the diff.
    pub signature: Option<String>,
    /// Presumably the prior account state, kept as raw JSON (not interpreted here).
    pub pre: Option<serde_json::Value>,
    /// Presumably the new account state, kept as raw JSON (not interpreted here).
    pub post: Option<serde_json::Value>,
}
/// An account-diff notification tagged with the account whose subscription produced it.
#[derive(Debug, Clone)]
pub struct RoutedAccountDiffNotification {
    /// Account address looked up from the notification's subscription id.
    pub account: String,
    /// The notification payload as received.
    pub notification: AccountDiffNotification,
}
/// Handle returned by the account-diff and program-diff subscription functions.
pub struct AccountDiffSubscriptionHandle {
    /// The background task driving the websocket subscription.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to ask the background task to drain pending notifications and stop.
    pub stop: watch::Sender<bool>,
}
/// Builds the [`SubscriptionRuntimeError::Closed`] error for a `kind` subscription on
/// `target` whose stream ended unexpectedly.
fn subscription_runtime_closed(
    kind: &'static str,
    target: impl Into<String>,
) -> SubscriptionRuntimeError {
    let target: String = target.into();
    SubscriptionRuntimeError::Closed { kind, target }
}
/// Builds the [`SubscriptionRuntimeError::CallbackWorker`] error for a `kind` subscription on
/// `target` whose callback task failed with `source`.
fn callback_worker_failed(
    kind: &'static str,
    target: impl Into<String>,
    source: tokio::task::JoinError,
) -> SubscriptionRuntimeError {
    let target: String = target.into();
    SubscriptionRuntimeError::CallbackWorker {
        kind,
        target,
        source,
    }
}
pub async fn subscribe_account_diffs<F, Fut>(
rpc_endpoint: &str,
account: &str,
on_notification: F,
) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
where
F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
on_notification(notification.notification)
})
.await
}
/// Subscribes to `accountDiffSubscribe` notifications for every unique account in
/// `accounts` over a single raw websocket.
///
/// Duplicate accounts are removed, keeping first-seen order. Each notification is tagged
/// with the account whose subscription produced it. Callbacks run one at a time, in arrival
/// order, on a dedicated worker task. If `accounts` is empty, no connection is made and the
/// returned handle's task finishes immediately with `Ok(())`.
///
/// The returned future resolves once every subscription is confirmed or setup has failed.
/// Send `true` on [`AccountDiffSubscriptionHandle::stop`] to request a graceful stop.
///
/// # Errors
///
/// - [`SubscriptionError::InvalidUrl`] if `rpc_endpoint` cannot be converted.
/// - [`SubscriptionError::Connect`] if the websocket cannot be opened.
/// - The error from the subscribe exchange if it fails.
/// - [`SubscriptionError::TaskDropped`] if the background task ends before reporting
///   readiness.
///
/// If the callback worker fails, the task returns
/// [`SubscriptionRuntimeError::CallbackWorker`]; this takes precedence over any stream
/// error.
pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
    rpc_endpoint: &str,
    accounts: I,
    on_notification: F,
) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
where
    F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let ws_url = http_to_ws_url(rpc_endpoint)?;
    let accounts = dedup_accounts(accounts);
    if accounts.is_empty() {
        // Nothing to subscribe to: hand back a task that completes successfully right away.
        let (stop_tx, stop_rx) = watch::channel(false);
        return Ok(AccountDiffSubscriptionHandle {
            join_handle: tokio::spawn(async move {
                let _ = stop_rx;
                Ok(())
            }),
            stop: stop_tx,
        });
    }
    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
    let (stop_tx, mut stop_rx) = watch::channel(false);
    let target = format!("{} accounts", accounts.len());
    let join_handle = tokio::spawn(async move {
        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
        // Single worker so callbacks observe notifications in arrival order.
        let callback_handle = tokio::spawn(async move {
            while let Some(notification) = notification_rx.recv().await {
                on_notification(notification).await;
            }
        });
        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
            Ok(connection) => connection,
            Err(e) => {
                let _ = ready_tx.send(Err(SubscriptionError::Connect {
                    url: ws_url,
                    source: Box::new(e),
                }));
                return Ok(());
            }
        };
        let subscriptions =
            match send_account_diff_subscribe_many(&mut ws, &accounts, &notification_tx).await {
                Ok(subscriptions) => subscriptions,
                Err(error) => {
                    let _ = ready_tx.send(Err(error));
                    return Ok(());
                }
            };
        let _ = ready_tx.send(Ok(()));
        let stream_result =
            drive_account_diff_stream_many(&mut ws, &subscriptions, &notification_tx, &mut stop_rx)
                .await;
        // Closing the channel lets the worker finish queued callbacks and then exit.
        drop(notification_tx);
        if let Err(source) = callback_handle.await {
            return Err(callback_worker_failed("account diff", target, source));
        }
        stream_result
    });
    match ready_rx.await {
        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
            join_handle,
            stop: stop_tx,
        }),
        Ok(Err(e)) => {
            join_handle.abort();
            Err(e)
        }
        Err(_) => {
            join_handle.abort();
            Err(SubscriptionError::TaskDropped)
        }
    }
}
/// Wire envelope of a server-pushed account-diff message. Only `method` and `params` are
/// read; other fields (such as `jsonrpc`) are ignored by serde.
#[derive(Deserialize)]
struct AccountDiffMessage {
    /// JSON-RPC method name; notifications use `accountDiffNotification`.
    method: String,
    params: AccountDiffParams,
}
/// `params` object of an account-diff message.
#[derive(Deserialize)]
struct AccountDiffParams {
    /// Server-assigned subscription id, used to route the notification to its account.
    subscription: u64,
    /// The notification payload.
    result: AccountDiffNotification,
}
async fn send_account_diff_subscribe_many(
ws: &mut AccountDiffWs,
accounts: &[String],
notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
#[derive(Deserialize)]
struct SubscriptionConfirmation {
id: u64,
result: Option<u64>,
}
let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
for (index, account) in accounts.iter().enumerate() {
let request_id = (index + 1) as u64;
let req = serde_json::json!({
"jsonrpc": "2.0",
"id": request_id,
"method": "accountDiffSubscribe",
"params": [account]
});
ws.send(Message::Text(req.to_string()))
.await
.map_err(|source| SubscriptionError::Subscribe {
source: Box::new(source),
})?;
pending.insert(request_id, account.clone());
}
while !pending.is_empty() {
match ws.next().await {
Some(Ok(Message::Text(text))) => {
if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
let Some(account) = pending.remove(&confirmation.id) else {
continue;
};
let Some(subscription_id) = confirmation.result else {
return Err(SubscriptionError::TaskDropped);
};
subscriptions.insert(subscription_id, account);
continue;
}
if let Some(notification) =
parse_routed_account_diff_notification(&text, &subscriptions)
{
let _ = notification_tx.send(notification);
}
}
Some(Ok(_)) => {}
_ => return Err(SubscriptionError::TaskDropped),
}
}
Ok(subscriptions)
}
/// Forwards routed account-diff notifications from `ws` to `notification_tx` until stopped.
///
/// Text frames are parsed and routed through `subscriptions`; anything unparsable or for an
/// unknown subscription is dropped, and non-text frames are ignored. A stop is requested
/// when `stop_rx` reads `true` or its sender is dropped. The socket is then drained of
/// notifications already in flight, and `Ok(())` is returned once any of these happens:
/// - it has been idle for [`SUBSCRIPTION_DRAIN_IDLE_TIMEOUT`];
/// - it closes or errors;
/// - [`SUBSCRIPTION_DRAIN_MAX_DURATION`] elapses.
///
/// # Errors
///
/// [`SubscriptionRuntimeError::Closed`] if the socket closes or errors before a stop is
/// requested.
async fn drive_account_diff_stream_many(
    ws: &mut AccountDiffWs,
    subscriptions: &std::collections::HashMap<u64, String>,
    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
    stop_rx: &mut watch::Receiver<bool>,
) -> Result<(), SubscriptionRuntimeError> {
    // Once the stop sender is gone, `changed()` fails immediately on every call; remember
    // that and treat it as a stop request instead of spinning in `select!`.
    let mut stop_sender_dropped = false;
    loop {
        if stop_sender_dropped || *stop_rx.borrow() {
            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
            loop {
                match tokio::time::timeout_at(
                    drain_deadline,
                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
                )
                .await
                {
                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
                        if let Some(notification) =
                            parse_routed_account_diff_notification(&text, subscriptions)
                        {
                            let _ = notification_tx.send(notification);
                        }
                    }
                    // Ping/pong/binary frames carry no notification; keep draining like the
                    // main loop does instead of cutting the drain short.
                    Ok(Ok(Some(Ok(_)))) => {}
                    _ => return Ok(()),
                }
            }
        }
        let msg = tokio::select! {
            m = ws.next() => m,
            changed = stop_rx.changed() => {
                stop_sender_dropped = changed.is_err();
                continue;
            }
        };
        match msg {
            Some(Ok(Message::Text(text))) => {
                if let Some(notification) =
                    parse_routed_account_diff_notification(&text, subscriptions)
                {
                    let _ = notification_tx.send(notification);
                }
            }
            Some(Ok(_)) => {}
            _ => {
                return Err(subscription_runtime_closed(
                    "account diff",
                    format!("{} accounts", subscriptions.len()),
                ));
            }
        }
    }
}
/// Parses `text` as an account-diff message, returning `None` unless it is valid JSON of the
/// expected shape whose `method` is exactly `accountDiffNotification`.
fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
    match serde_json::from_str::<AccountDiffMessage>(text) {
        Ok(message) if message.method == "accountDiffNotification" => Some(message),
        _ => None,
    }
}
fn parse_routed_account_diff_notification(
text: &str,
subscriptions: &std::collections::HashMap<u64, String>,
) -> Option<RoutedAccountDiffNotification> {
let msg = parse_account_diff_message(text)?;
let account = subscriptions.get(&msg.params.subscription)?.clone();
Some(RoutedAccountDiffNotification {
account,
notification: msg.params.result,
})
}
/// Converts `accounts` to owned strings and removes duplicates, keeping the first
/// occurrence of each account in its original position.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::BTreeSet::new();
    let mut ordered = Vec::new();
    for candidate in accounts {
        let candidate: String = candidate.into();
        if seen.insert(candidate.clone()) {
            ordered.push(candidate);
        }
    }
    ordered
}
/// Subscribes via `accountDiffSubscribe` with `address_type: "program"` for `program_id`
/// over a raw websocket.
///
/// NOTE(review): presumably the server then streams diffs for accounts owned by the program;
/// confirm against the RPC provider's documentation.
///
/// Callbacks run one at a time, in arrival order, on a dedicated worker task. The returned
/// future resolves once the subscription is confirmed or setup has failed. Send `true` on
/// [`AccountDiffSubscriptionHandle::stop`] to request a graceful stop.
///
/// # Errors
///
/// - [`SubscriptionError::InvalidUrl`] if `rpc_endpoint` cannot be converted.
/// - [`SubscriptionError::Connect`] if the websocket cannot be opened.
/// - The error from the subscribe exchange if it fails.
/// - [`SubscriptionError::TaskDropped`] if the background task ends before reporting
///   readiness.
///
/// If the callback worker fails, the task returns
/// [`SubscriptionRuntimeError::CallbackWorker`]; this takes precedence over any stream
/// error.
pub async fn subscribe_program_diffs<F, Fut>(
    rpc_endpoint: &str,
    program_id: &str,
    on_notification: F,
) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
where
    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let ws_url = http_to_ws_url(rpc_endpoint)?;
    let program_id = program_id.to_string();
    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
    let (stop_tx, mut stop_rx) = watch::channel(false);
    let join_handle = tokio::spawn(async move {
        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
        // Single worker so callbacks observe notifications in arrival order.
        let callback_handle = tokio::spawn(async move {
            while let Some(notification) = notification_rx.recv().await {
                on_notification(notification).await;
            }
        });
        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
            Ok(connection) => connection,
            Err(e) => {
                let _ = ready_tx.send(Err(SubscriptionError::Connect {
                    url: ws_url,
                    source: Box::new(e),
                }));
                return Ok(());
            }
        };
        if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
            let _ = ready_tx.send(Err(error));
            return Ok(());
        }
        let _ = ready_tx.send(Ok(()));
        let stream_result =
            drive_program_diff_stream(&mut ws, &notification_tx, &mut stop_rx, &program_id).await;
        // Closing the channel lets the worker finish queued callbacks and then exit.
        drop(notification_tx);
        if let Err(source) = callback_handle.await {
            return Err(callback_worker_failed(
                "program account diff",
                &program_id,
                source,
            ));
        }
        stream_result
    });
    match ready_rx.await {
        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
            join_handle,
            stop: stop_tx,
        }),
        Ok(Err(e)) => {
            join_handle.abort();
            Err(e)
        }
        Err(_) => {
            join_handle.abort();
            Err(SubscriptionError::TaskDropped)
        }
    }
}
async fn send_program_diff_subscribe(
ws: &mut AccountDiffWs,
program_id: &str,
) -> Result<(), SubscriptionError> {
#[derive(Deserialize)]
struct SubscriptionConfirmation {
result: Option<u64>,
}
let req = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "accountDiffSubscribe",
"params": [program_id, {"address_type": "program"}]
});
ws.send(Message::Text(req.to_string()))
.await
.map_err(|source| SubscriptionError::Subscribe {
source: Box::new(source),
})?;
loop {
match ws.next().await {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<SubscriptionConfirmation>(&text) {
Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
Ok(_) => continue,
Err(source) => {
return Err(SubscriptionError::Subscribe {
source: Box::new(source),
});
}
}
}
Some(Ok(_)) => continue,
_ => return Err(SubscriptionError::TaskDropped),
}
}
}
/// Forwards account-diff notifications from `ws` to `notification_tx` until stopped.
///
/// Text frames that parse as `accountDiffNotification` messages are forwarded; other text
/// frames are dropped and non-text frames are ignored. A stop is requested when `stop_rx`
/// reads `true` or its sender is dropped. The socket is then drained of notifications
/// already in flight, and `Ok(())` is returned once any of these happens:
/// - it has been idle for [`SUBSCRIPTION_DRAIN_IDLE_TIMEOUT`];
/// - it closes or errors;
/// - [`SUBSCRIPTION_DRAIN_MAX_DURATION`] elapses.
///
/// # Errors
///
/// [`SubscriptionRuntimeError::Closed`] if the socket closes or errors before a stop is
/// requested.
async fn drive_program_diff_stream(
    ws: &mut AccountDiffWs,
    notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
    stop_rx: &mut watch::Receiver<bool>,
    program_id: &str,
) -> Result<(), SubscriptionRuntimeError> {
    // Once the stop sender is gone, `changed()` fails immediately on every call; remember
    // that and treat it as a stop request instead of spinning in `select!`.
    let mut stop_sender_dropped = false;
    loop {
        if stop_sender_dropped || *stop_rx.borrow() {
            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
            loop {
                match tokio::time::timeout_at(
                    drain_deadline,
                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
                )
                .await
                {
                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
                        if let Some(msg) = parse_account_diff_message(&text) {
                            let _ = notification_tx.send(msg.params.result);
                        }
                    }
                    // Ping/pong/binary frames carry no notification; keep draining like the
                    // main loop does instead of cutting the drain short.
                    Ok(Ok(Some(Ok(_)))) => {}
                    _ => return Ok(()),
                }
            }
        }
        let msg = tokio::select! {
            m = ws.next() => m,
            changed = stop_rx.changed() => {
                stop_sender_dropped = changed.is_err();
                continue;
            }
        };
        match msg {
            Some(Ok(Message::Text(text))) => {
                if let Some(msg) = parse_account_diff_message(&text) {
                    let _ = notification_tx.send(msg.params.result);
                }
            }
            Some(Ok(_)) => {}
            _ => {
                return Err(subscription_runtime_closed(
                    "program account diff",
                    program_id,
                ));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let confirmation = r#"{"jsonrpc":"2.0","result":1,"id":1}"#;
        assert!(parse_account_diff_message(confirmation).is_none());
    }

    #[test]
    fn parse_account_diff_message_rejects_other_methods() {
        // Well-formed envelope, but not an account-diff notification.
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"slotNotification",
            "params":{"subscription":1,"result":{"context":{"slot":1}}}
        }"#;
        assert!(parse_account_diff_message(text).is_none());
    }

    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;
        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 123);
        assert_eq!(notification.signature.as_deref(), Some("sig"));
        assert_eq!(notification.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(notification.post, Some(serde_json::json!({"a": 2})));
    }

    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);
        let notification =
            parse_routed_account_diff_notification(text, &subscriptions).expect("notification");
        assert_eq!(notification.account, "acct");
        assert_eq!(notification.notification.context.slot, 456);
    }

    #[test]
    fn parse_routed_account_diff_notification_drops_unknown_subscription() {
        // Valid notification, but for a subscription id we never registered.
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{"subscription":9,"result":{"context":{"slot":1}}}
        }"#;
        assert!(parse_account_diff_message(text).is_some());
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);
        assert!(parse_routed_account_diff_notification(text, &subscriptions).is_none());
    }

    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let accounts = dedup_accounts([
            "b".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }

    #[test]
    fn dedup_accounts_handles_empty_input() {
        assert!(dedup_accounts(Vec::<String>::new()).is_empty());
    }
}