sim-cli 0.3.0

CLI tool for running and comparing Solana simulator backtests
//! SubscriptionManager — owns subscription WebSockets and keeps them alive.
//!
//! Two variants:
//! - account-diff subscription (`accountDiffSubscribe`) — used by sim-cli for
//!   account state capture
//! - program-log subscription (`logsSubscribe`) — used by sim-cli for program
//!   log capture
//!
//! Both follow the same reconnect + keepalive pattern. On reconnect, all
//! configured subscriptions are re-established from scratch; we do not attempt
//! to replay notifications missed during the gap.

use std::{
    collections::HashSet, future::Future, marker::PhantomData, pin::Pin, sync::Arc, time::Instant,
};

use futures::{SinkExt, StreamExt};
use serde::{Deserialize, de::DeserializeOwned};
use simulator_client::{AccountDiffNotification, urls::http_to_ws_url};
use solana_client::rpc_response::{Response, RpcLogsResponse};
use tokio::{
    net::TcpStream,
    sync::watch,
    task::{JoinHandle, JoinSet},
};
use tokio_tungstenite::{
    MaybeTlsStream, WebSocketStream, connect_async,
    tungstenite::{Message, client::IntoClientRequest},
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};

use super::{
    CONNECT_TIMEOUT, ConnectionStatus, HANDSHAKE_RESPONSE_TIMEOUT, KEEPALIVE_INTERVAL,
    KEEPALIVE_MISS_DEADLINE, RECONNECT_UPTIME_RESET, ReconnectBudget, cancellable_sleep,
};

/// Handle to a running subscription manager task.
///
/// Returned by `spawn_subscription_manager`.
pub struct SubscriptionHandle {
    /// Latest connection state published by the manager task. The channel is
    /// created with `ConnectionStatus::Down` as its initial value.
    pub status: watch::Receiver<ConnectionStatus>,
    /// Join handle for the manager task. The task finishes once it stops
    /// reconnecting (cancelled or retry budget exhausted) and has awaited every
    /// notification callback it spawned.
    pub join: JoinHandle<()>,
}

/// Per-flavor differences between `accountDiffSubscribe` and `logsSubscribe`.
///
/// The implementors in this module are unit marker types; the manager is
/// generic over this trait and reads only the constants, the notification type
/// and `subscribe_params`.
pub trait SubKind: Send + Sync + 'static {
    /// Short name attached to log records as the `kind` field.
    const LABEL: &'static str;
    /// JSON-RPC method sent once per program id to open a subscription.
    const SUBSCRIBE_METHOD: &'static str;
    /// `method` value that notifications of this flavor carry.
    const NOTIFICATION_METHOD: &'static str;

    /// Type decoded from a notification's `params.result`.
    type Notification: DeserializeOwned + Send + 'static;

    /// Builds the `params` value of the subscribe request for one program id.
    fn subscribe_params(program_id: &str) -> serde_json::Value;
}

/// `accountDiffSubscribe` flavor, decoded into `AccountDiffNotification`.
pub struct AccountDiff;

impl SubKind for AccountDiff {
    const LABEL: &'static str = "account-diff";
    const SUBSCRIBE_METHOD: &'static str = "accountDiffSubscribe";
    const NOTIFICATION_METHOD: &'static str = "accountDiffNotification";

    type Notification = AccountDiffNotification;

    /// Produces `[program_id, {"address_type": "program"}]`.
    fn subscribe_params(program_id: &str) -> serde_json::Value {
        let options = serde_json::json!({ "address_type": "program" });
        serde_json::json!([program_id, options])
    }
}

/// `logsSubscribe` flavor, decoded into `Response<RpcLogsResponse>`.
pub struct ProgramLog;

impl SubKind for ProgramLog {
    const LABEL: &'static str = "program-log";
    const SUBSCRIBE_METHOD: &'static str = "logsSubscribe";
    const NOTIFICATION_METHOD: &'static str = "logsNotification";

    type Notification = Response<RpcLogsResponse>;

    /// Produces `[{"mentions": [program_id]}, {"commitment": "confirmed"}]`.
    fn subscribe_params(program_id: &str) -> serde_json::Value {
        let filter = serde_json::json!({ "mentions": [program_id] });
        let config = serde_json::json!({ "commitment": "confirmed" });
        serde_json::json!([filter, config])
    }
}

/// Type-erased, shareable notification handler.
///
/// Each call yields a boxed `Send` future with no output, suitable for
/// spawning onto a `JoinSet<()>`.
type Callback<N> = Arc<
    dyn Fn(N) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
>;

/// Spawn a background task that keeps `K` subscriptions for `program_ids` alive.
///
/// The task connects to the WebSocket URL derived from `rpc_endpoint`, sends
/// one subscribe request per program id, and hands every matching notification
/// to `on_notification` (each resulting future runs as its own spawned task).
/// On connection loss it retries with backoff until `cancel` fires or the retry
/// budget is exhausted. Connection state is published on the returned handle's
/// `status` channel, which starts as `ConnectionStatus::Down`.
pub fn spawn_subscription_manager<K, F, Fut>(
    rpc_endpoint: String,
    program_ids: Vec<String>,
    on_notification: F,
    cancel: CancellationToken,
) -> SubscriptionHandle
where
    K: SubKind,
    F: Fn(K::Notification) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let (status_tx, status) = watch::channel(ConnectionStatus::Down);

    // Erase the caller's closure/future types so `Task` is generic over `K` only.
    let callback: Callback<K::Notification> = Arc::new(move |n| Box::pin(on_notification(n)));

    let manager = Task::<K> {
        rpc_endpoint,
        program_ids,
        callback,
        status_tx,
        cancel,
        _marker: PhantomData,
    };

    SubscriptionHandle {
        status,
        join: tokio::spawn(manager.run()),
    }
}

/// Subscription ids handed back by the server for the current connection.
type Subs = HashSet<u64>;
/// Client WebSocket over a TCP stream that may or may not be TLS-wrapped.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// State owned by the background manager task; consumed by `Task::run`.
struct Task<K: SubKind> {
    /// RPC endpoint; converted with `http_to_ws_url` on every connect attempt.
    rpc_endpoint: String,
    /// Program ids to subscribe to; one subscribe request is sent per entry.
    program_ids: Vec<String>,
    /// Type-erased notification handler; each returned future is spawned.
    callback: Callback<K::Notification>,
    /// Publishes `Down` / `Up` / `Failed` connection state.
    status_tx: watch::Sender<ConnectionStatus>,
    /// Stops the reconnect loop and the message loop when cancelled.
    cancel: CancellationToken,
    // `fn() -> K` ties the struct to `K` without storing one; fn pointers are
    // always Send + Sync, so this marker adds no auto-trait requirements.
    _marker: PhantomData<fn() -> K>,
}

impl<K: SubKind> Task<K> {
    async fn run(self) {
        let mut tasks: JoinSet<()> = JoinSet::new();
        let mut budget = ReconnectBudget::new();

        loop {
            if self.cancel.is_cancelled() {
                break;
            }
            publish(&self.status_tx, ConnectionStatus::Down);

            let connect_result = async {
                let ws = connect_ws(&self.rpc_endpoint).await?;
                subscribe::<K>(ws, &self.program_ids).await
            }
            .await;

            let (ws, subs) = match connect_result {
                Ok(v) => v,
                Err(why) => {
                    if retry_or_fail::<K>(
                        "connect",
                        why,
                        &mut budget,
                        &self.cancel,
                        &self.status_tx,
                    )
                    .await
                    {
                        continue;
                    }
                    break;
                }
            };

            publish(&self.status_tx, ConnectionStatus::Up);
            let connected_at = Instant::now();

            let exit =
                message_loop::<K>(ws, subs, self.callback.clone(), &self.cancel, &mut tasks).await;

            match exit {
                MessageLoopExit::Cancelled => break,
                MessageLoopExit::ConnectionLost(why) => {
                    if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
                        budget.reset();
                    }
                    if retry_or_fail::<K>(
                        "connection lost",
                        why,
                        &mut budget,
                        &self.cancel,
                        &self.status_tx,
                    )
                    .await
                    {
                        continue;
                    }
                    break;
                }
            }
        }

        // Drain outstanding callbacks so the caller's `sub.join.await` actually
        // waits for in-flight work to finish (and so any callback panic is
        // surfaced rather than silently dropped).
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res
                && !e.is_cancelled()
            {
                warn!(kind = K::LABEL, error = %e, "subscription callback panicked");
            }
        }
    }
}

/// Reason `message_loop` handed control back to the reconnect loop.
enum MessageLoopExit {
    /// The cancellation token fired; the manager should stop.
    Cancelled,
    /// The connection is unusable (read error, remote close, stream end,
    /// failed ping, or keepalive deadline missed); the string says which.
    ConnectionLost(String),
}

/// Pump notifications from one established connection until it is cancelled or lost.
///
/// Text frames (and binary frames that are valid UTF-8) that parse as a `K`
/// notification for one of `subs` are passed to `callback`, and the returned
/// future is spawned onto `tasks`. A ping is sent on every `KEEPALIVE_INTERVAL`
/// tick; if nothing has been received for longer than `KEEPALIVE_MISS_DEADLINE`
/// when the timer fires, the connection is reported as lost.
///
/// Finished callback tasks are reaped from `tasks` while the loop runs, and a
/// panicking callback is logged as soon as it is reaped. Callbacks still
/// running when this returns stay in `tasks` for the caller to await.
async fn message_loop<K: SubKind>(
    mut ws: Ws,
    subs: Subs,
    callback: Callback<K::Notification>,
    cancel: &CancellationToken,
    tasks: &mut JoinSet<()>,
) -> MessageLoopExit {
    let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
    ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    let mut last_inbound = Instant::now();

    loop {
        tokio::select! {
            biased;
            _ = cancel.cancelled() => return MessageLoopExit::Cancelled,

            _ = ping_timer.tick() => {
                let silent_for = last_inbound.elapsed();
                if silent_for > KEEPALIVE_MISS_DEADLINE {
                    return MessageLoopExit::ConnectionLost(format!(
                        "no traffic for {silent_for:?}"
                    ));
                }
                if let Err(e) = ws.send(Message::Ping(vec![])).await {
                    return MessageLoopExit::ConnectionLost(format!("ping send: {e}"));
                }
            }

            // A JoinSet retains every finished task until it is joined, so
            // without reaping here the set would grow by one entry per
            // notification for the lifetime of the connection. This branch is
            // polled ahead of `ws.next()` so a steady inbound stream cannot
            // starve it.
            Some(res) = tasks.join_next(), if !tasks.is_empty() => {
                if let Err(e) = res
                    && !e.is_cancelled()
                {
                    warn!(kind = K::LABEL, error = %e, "subscription callback panicked");
                }
            }

            msg = ws.next() => {
                // Any inbound frame (including pings/pongs) counts as liveness.
                last_inbound = Instant::now();
                match msg {
                    Some(Ok(Message::Text(t))) => {
                        if let Some(n) = parse_notification::<K>(&t, &subs) {
                            tasks.spawn(callback(n));
                        }
                    }
                    Some(Ok(Message::Binary(b))) => {
                        if let Ok(t) = std::str::from_utf8(&b)
                            && let Some(n) = parse_notification::<K>(t, &subs)
                        {
                            tasks.spawn(callback(n));
                        }
                    }
                    Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
                    Some(Ok(Message::Close(frame))) => {
                        return MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
                    }
                    Some(Ok(Message::Frame(_))) => {}
                    Some(Err(e)) => return MessageLoopExit::ConnectionLost(format!("ws read: {e}")),
                    None => return MessageLoopExit::ConnectionLost("ws stream ended".into()),
                }
            }
        }
    }
}

/// Decide whether the reconnect loop should try again after `phase` failed.
///
/// If the budget yields another backoff delay, logs a warning and sleeps via
/// `cancellable_sleep`, returning its result. Otherwise publishes
/// `ConnectionStatus::Failed("{phase}: {reason}")` and returns `false`.
async fn retry_or_fail<K: SubKind>(
    phase: &'static str,
    reason: String,
    budget: &mut ReconnectBudget,
    cancel: &CancellationToken,
    status_tx: &watch::Sender<ConnectionStatus>,
) -> bool {
    let Some(delay) = budget.next_backoff() else {
        // Out of retries: leave the final reason visible to status watchers.
        let final_status = ConnectionStatus::Failed(format!("{phase}: {reason}"));
        publish(status_tx, final_status);
        return false;
    };

    warn!(
        kind = K::LABEL,
        attempt = budget.attempt(),
        reason = %reason,
        ?delay,
        "subscription {phase}, retrying",
    );
    // NOTE(review): presumably `cancellable_sleep` returns false when `cancel`
    // fires mid-sleep — confirm against its definition.
    cancellable_sleep(delay, cancel).await
}

/// Store `status` in the watch channel, waking receivers only if it differs
/// from the value currently held.
fn publish(tx: &watch::Sender<ConnectionStatus>, status: ConnectionStatus) {
    tx.send_if_modified(|slot| {
        let unchanged = *slot == status;
        if !unchanged {
            *slot = status;
        }
        !unchanged
    });
}

/// Open a WebSocket to the URL derived from `rpc_endpoint`.
///
/// The whole handshake is bounded by `CONNECT_TIMEOUT`. Every failure
/// (URL conversion, request build, timeout, handshake error) is returned as a
/// human-readable string.
async fn connect_ws(rpc_endpoint: &str) -> Result<Ws, String> {
    let url = http_to_ws_url(rpc_endpoint).map_err(|e| e.to_string())?;
    let request = url
        .into_client_request()
        .map_err(|e| format!("build request: {e}"))?;

    let handshake = connect_async(request);
    let (stream, _response) = match tokio::time::timeout(CONNECT_TIMEOUT, handshake).await {
        Err(_elapsed) => return Err(format!("connect timeout after {CONNECT_TIMEOUT:?}")),
        Ok(Err(e)) => return Err(format!("connect: {e}")),
        Ok(Ok(pair)) => pair,
    };
    Ok(stream)
}

/// Send one `K::SUBSCRIBE_METHOD` request per program id and collect the ids the
/// server assigns.
///
/// Requests are sent sequentially; each waits for its own ack before the next
/// is sent. JSON-RPC request ids are 1-based, in `program_ids` order. Returns the
/// socket back together with the set of subscription ids.
async fn subscribe<K: SubKind>(mut ws: Ws, program_ids: &[String]) -> Result<(Ws, Subs), String> {
    let mut subs = Subs::new();

    for (id, program_id) in (1u64..).zip(program_ids) {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": K::SUBSCRIBE_METHOD,
            "params": K::subscribe_params(program_id),
        });

        let sent = ws.send(Message::Text(request.to_string())).await;
        if let Err(e) = sent {
            return Err(format!("subscribe send: {e}"));
        }

        let sub_id = read_sub_ack(&mut ws, id).await?;
        subs.insert(sub_id);
    }

    debug!(
        kind = K::LABEL,
        count = subs.len(),
        "subscriptions established"
    );
    Ok((ws, subs))
}

/// JSON-RPC response to a subscribe request, as read by `read_sub_ack`.
#[derive(Deserialize)]
struct SubAck {
    /// Request id this response answers; compared with the id that was sent.
    id: u64,
    /// Server-assigned subscription id when the request succeeded.
    result: Option<u64>,
    /// JSON-RPC error object when the request was rejected.
    #[serde(default)]
    error: Option<serde_json::Value>,
}

/// Wait for the JSON-RPC response to the subscribe request with id `expected_id`
/// and return the subscription id it carries.
///
/// Text frames and UTF-8 binary frames are both accepted, matching how
/// `message_loop` reads notifications. Frames that are not a response for
/// `expected_id` are skipped. This includes pings/pongs, responses to other
/// ids, and anything without a numeric `id` such as a notification for a
/// subscription established earlier in the same handshake. Skipped
/// notifications are not delivered.
///
/// # Errors
/// Returns a description if the overall `HANDSHAKE_RESPONSE_TIMEOUT` elapses,
/// the stream errors, ends, or is closed by the peer, or the matching response
/// carries an `error` or lacks a `result`.
async fn read_sub_ack(ws: &mut Ws, expected_id: u64) -> Result<u64, String> {
    // One deadline for the whole wait, however many unrelated frames arrive.
    let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
    loop {
        let msg = tokio::time::timeout_at(deadline, ws.next())
            .await
            .map_err(|_| format!("subscribe ack timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;

        let Some(msg) = msg else {
            return Err("ws ended during subscribe".into());
        };
        let msg = msg.map_err(|e| format!("ws read: {e}"))?;

        let ack = match &msg {
            Message::Text(t) => serde_json::from_str::<SubAck>(t).ok(),
            Message::Binary(b) => std::str::from_utf8(b)
                .ok()
                .and_then(|t| serde_json::from_str::<SubAck>(t).ok()),
            Message::Close(frame) => {
                return Err(format!("remote close during subscribe: {frame:?}"));
            }
            _ => None,
        };
        let Some(ack) = ack else {
            continue;
        };

        if ack.id != expected_id {
            continue;
        }
        if let Some(err) = ack.error {
            return Err(format!("subscribe rejected: {err}"));
        }
        return ack
            .result
            .ok_or_else(|| "subscribe ack missing result".to_string());
    }
}

/// Decode a `K` notification addressed to one of `subs`.
///
/// Returns `None` when `text` does not deserialize into the
/// `{method, params: {subscription, result}}` envelope, when `method` is not
/// `K::NOTIFICATION_METHOD`, or when `params.subscription` is not in `subs`.
fn parse_notification<K: SubKind>(text: &str, subs: &Subs) -> Option<K::Notification> {
    // Field names and order mirror the wire envelope.
    #[derive(Deserialize)]
    #[serde(bound = "T: DeserializeOwned")]
    struct Envelope<T> {
        method: String,
        params: EnvelopeParams<T>,
    }

    #[derive(Deserialize)]
    #[serde(bound = "T: DeserializeOwned")]
    struct EnvelopeParams<T> {
        subscription: u64,
        result: T,
    }

    let Envelope { method, params } =
        serde_json::from_str::<Envelope<K::Notification>>(text).ok()?;
    let ours = method == K::NOTIFICATION_METHOD && subs.contains(&params.subscription);
    ours.then_some(params.result)
}