Skip to main content

tidepool_server/
ws.rs

1//! WebSocket server with polling-based subscription polyfills.
2//!
3//! Surfpool's native WS doesn't implement Solana's subscription
4//! methods, so `@solana/web3.js`'s `confirmTransaction()` and similar
5//! hang against it. We polyfill the commonly-needed subscriptions via
6//! periodic HTTP polling of the upstream RPC URL:
7//!
8//! - `signatureSubscribe` → `getSignatureStatuses` (one-shot; resolves
9//!   when the tx reaches the requested commitment).
//! - `accountSubscribe` → `getAccountInfo` (long-lived; emits the
//!   current account state on the first successful poll, then a
//!   notification every time the observed state changes).
//! - `logsSubscribe({ mentions: [pubkey] })` → `getSignaturesForAddress`
//!   plus a `getTransaction` fan-out. Each new signature that mentions
//!   the given pubkey emits a `logsNotification` carrying the
//!   transaction's log messages. The `all` / `allWithVotes` filters
//!   aren't polyfilled (there's no efficient polling shim), so clients
//!   asking for them get a typed error.
13//!
//! Other subscription methods (`programSubscribe`, `slotSubscribe`,
//! etc.) are not yet polyfilled; requests for them are answered with a
//! JSON-RPC `-32601` error.
16//!
17//! Per-connection state lives for the lifetime of the WS upgrade —
18//! when the client disconnects, all outstanding polling tasks are
19//! cancelled.
20
21use std::sync::{
22    atomic::{AtomicU64, Ordering},
23    Arc,
24};
25use std::time::Duration;
26
27use axum::{
28    extract::{
29        ws::{Message, WebSocket},
30        State, WebSocketUpgrade,
31    },
32    response::Response,
33    routing::get,
34    Router,
35};
36use futures_util::{SinkExt, StreamExt};
37use serde_json::{json, Value};
38use tokio::net::TcpListener;
39use tokio::sync::{mpsc, Mutex};
40use tokio::task::JoinHandle;
41use tracing::{info, warn};
42
/// Delay between consecutive upstream polls. Every polyfill (signature,
/// account, logs) sleeps this long before each request. Mirrors the TS
/// implementation: short enough to feel responsive, long enough not to
/// hammer the upstream RPC.
const POLL_INTERVAL: Duration = Duration::from_millis(500);

/// Upper bound on `getSignatureStatuses` attempts per
/// `signatureSubscribe`: 240 polls × 500 ms ≈ 120 s, roughly in line with
/// `confirmTransaction`'s own client-side timeout. Local Surfpool
/// transactions confirm within seconds; mainnet ones may take longer.
const MAX_POLLS: u32 = 2 * 120; // two polls per second for 120 s

/// Process-wide subscription id counter (starts at 1), so ids never
/// collide across connections.
static NEXT_SUB_ID: AtomicU64 = AtomicU64::new(1);
57
/// Shared configuration handed to every WS connection and cloned into
/// each of its polling tasks.
///
/// Derives `Debug` so the config can be logged / inspected like any
/// other public type.
#[derive(Clone, Debug)]
pub struct WsState {
    /// HTTP JSON-RPC endpoint the subscription polyfills poll.
    pub upstream_url: String,
    /// Per-request timeout for the HTTP client each polling task builds.
    pub rpc_timeout: Duration,
}
63
64/// Spawn the WS server on `port`. Returns a handle that can be
65/// awaited to block until the WS server shuts down — callers
66/// typically just `tokio::spawn` it and let it run alongside the HTTP
67/// server.
68pub async fn run_ws(
69    port: u16,
70    upstream_url: String,
71    rpc_timeout: Duration,
72) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
73    let state = WsState {
74        upstream_url,
75        rpc_timeout,
76    };
77    let app = Router::new().route("/", get(ws_upgrade)).with_state(state);
78    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
79    let listener = TcpListener::bind(&addr).await?;
80    info!("tidepool WS listening on ws://{addr}");
81    axum::serve(listener, app).await?;
82    Ok(())
83}
84
85async fn ws_upgrade(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
86    ws.on_upgrade(move |socket| handle_connection(socket, state))
87}
88
89/// One connection's lifetime. The main task owns the outbound sink;
90/// polling tasks forward notifications via an mpsc channel.
91#[allow(clippy::too_many_lines)]
92async fn handle_connection(socket: WebSocket, state: WsState) {
93    let (mut sink, mut stream) = socket.split();
94    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
95
96    // Shared map of active subscriptions so signatureUnsubscribe can
97    // cancel the corresponding polling task.
98    let subs: Arc<Mutex<std::collections::HashMap<u64, JoinHandle<()>>>> =
99        Arc::new(Mutex::new(std::collections::HashMap::new()));
100
101    // Spawn a task that writes every outgoing message sequentially.
102    // Multiple polling tasks push into `tx`; ordering per-
103    // subscription is preserved; cross-subscription ordering doesn't
104    // matter (clients dispatch by subscription id).
105    let write_task = tokio::spawn(async move {
106        while let Some(msg) = rx.recv().await {
107            if sink.send(msg).await.is_err() {
108                break;
109            }
110        }
111    });
112
113    // Main read loop.
114    while let Some(Ok(msg)) = stream.next().await {
115        let Message::Text(text) = msg else {
116            // Binary / ping / pong / close — let axum's ws layer
117            // handle control frames; binary we don't use.
118            if matches!(msg, Message::Close(_)) {
119                break;
120            }
121            continue;
122        };
123
124        let Ok(req) = serde_json::from_str::<Value>(&text) else {
125            continue;
126        };
127        let method = req.get("method").and_then(Value::as_str).unwrap_or("");
128        let id = req.get("id").cloned().unwrap_or(Value::Null);
129
130        match method {
131            "signatureSubscribe" => {
132                let sub_id = NEXT_SUB_ID.fetch_add(1, Ordering::Relaxed);
133                let Some(signature) = req
134                    .get("params")
135                    .and_then(Value::as_array)
136                    .and_then(|a| a.first())
137                    .and_then(Value::as_str)
138                    .map(String::from)
139                else {
140                    send(&tx, &error_msg(&id, -32602, "missing signature param"));
141                    continue;
142                };
143                let commitment = req
144                    .get("params")
145                    .and_then(Value::as_array)
146                    .and_then(|a| a.get(1))
147                    .and_then(|v| v.get("commitment"))
148                    .and_then(Value::as_str)
149                    .unwrap_or("finalized")
150                    .to_string();
151
152                // Ack immediately with the subscription id.
153                send(
154                    &tx,
155                    &json!({ "jsonrpc": "2.0", "id": id, "result": sub_id }),
156                );
157
158                // Spawn the polling task.
159                let poll_tx = tx.clone();
160                let state_clone = state.clone();
161                let subs_clone = Arc::clone(&subs);
162                let handle = tokio::spawn(async move {
163                    poll_signature(sub_id, signature, commitment, state_clone, poll_tx).await;
164                    // Remove the sub from the map on completion so
165                    // signatureUnsubscribe doesn't try to abort a
166                    // finished task.
167                    subs_clone.lock().await.remove(&sub_id);
168                });
169                subs.lock().await.insert(sub_id, handle);
170            }
171
172            "accountSubscribe" => {
173                let sub_id = NEXT_SUB_ID.fetch_add(1, Ordering::Relaxed);
174                let Some(pubkey) = req
175                    .get("params")
176                    .and_then(Value::as_array)
177                    .and_then(|a| a.first())
178                    .and_then(Value::as_str)
179                    .map(String::from)
180                else {
181                    send(&tx, &error_msg(&id, -32602, "missing account pubkey param"));
182                    continue;
183                };
184                let opts = req
185                    .get("params")
186                    .and_then(Value::as_array)
187                    .and_then(|a| a.get(1))
188                    .cloned()
189                    .unwrap_or(Value::Null);
190                let commitment = opts
191                    .get("commitment")
192                    .and_then(Value::as_str)
193                    .unwrap_or("finalized")
194                    .to_string();
195                // Default Solana RPC encoding for accountSubscribe is
196                // base58; clients usually want base64 or jsonParsed.
197                let encoding = opts
198                    .get("encoding")
199                    .and_then(Value::as_str)
200                    .unwrap_or("base64")
201                    .to_string();
202
203                send(
204                    &tx,
205                    &json!({ "jsonrpc": "2.0", "id": id, "result": sub_id }),
206                );
207
208                let poll_tx = tx.clone();
209                let state_clone = state.clone();
210                let subs_clone = Arc::clone(&subs);
211                let handle = tokio::spawn(async move {
212                    poll_account(sub_id, pubkey, commitment, encoding, state_clone, poll_tx).await;
213                    subs_clone.lock().await.remove(&sub_id);
214                });
215                subs.lock().await.insert(sub_id, handle);
216            }
217
218            "logsSubscribe" => {
219                let sub_id = NEXT_SUB_ID.fetch_add(1, Ordering::Relaxed);
220                let params = req.get("params").and_then(Value::as_array);
221                let filter = params
222                    .and_then(|a| a.first())
223                    .cloned()
224                    .unwrap_or(Value::Null);
225                // Supported filter shapes: `{ mentions: [pubkey] }`.
226                // Reject `"all"` / `"allWithVotes"` with a typed error —
227                // there's no cheap polling shim for them.
228                let mention = match &filter {
229                    Value::Object(map) => map
230                        .get("mentions")
231                        .and_then(Value::as_array)
232                        .and_then(|a| a.first())
233                        .and_then(Value::as_str)
234                        .map(String::from),
235                    Value::String(s) if s == "all" || s == "allWithVotes" => {
236                        send(
237                            &tx,
238                            &error_msg(
239                                &id,
240                                -32601,
241                                "logsSubscribe with filter 'all' / 'allWithVotes' is not \
242                                 polyfilled by the tidepool WS shim; use { mentions: [pubkey] }",
243                            ),
244                        );
245                        continue;
246                    }
247                    _ => None,
248                };
249                let Some(mention) = mention else {
250                    send(
251                        &tx,
252                        &error_msg(
253                            &id,
254                            -32602,
255                            "logsSubscribe requires `{ mentions: [pubkey] }` filter",
256                        ),
257                    );
258                    continue;
259                };
260                let commitment = params
261                    .and_then(|a| a.get(1))
262                    .and_then(|v| v.get("commitment"))
263                    .and_then(Value::as_str)
264                    .unwrap_or("finalized")
265                    .to_string();
266
267                send(
268                    &tx,
269                    &json!({ "jsonrpc": "2.0", "id": id, "result": sub_id }),
270                );
271
272                let poll_tx = tx.clone();
273                let state_clone = state.clone();
274                let subs_clone = Arc::clone(&subs);
275                let handle = tokio::spawn(async move {
276                    poll_logs(sub_id, mention, commitment, state_clone, poll_tx).await;
277                    subs_clone.lock().await.remove(&sub_id);
278                });
279                subs.lock().await.insert(sub_id, handle);
280            }
281
282            "signatureUnsubscribe" | "accountUnsubscribe" | "logsUnsubscribe" => {
283                let Some(sub_id) = req
284                    .get("params")
285                    .and_then(Value::as_array)
286                    .and_then(|a| a.first())
287                    .and_then(Value::as_u64)
288                else {
289                    send(&tx, &error_msg(&id, -32602, "missing subscription id"));
290                    continue;
291                };
292                let removed = subs.lock().await.remove(&sub_id);
293                let was_present = removed.is_some();
294                if let Some(handle) = removed {
295                    handle.abort();
296                }
297                send(
298                    &tx,
299                    &json!({
300                        "jsonrpc": "2.0",
301                        "id": id,
302                        "result": was_present
303                    }),
304                );
305            }
306
307            // Every other method is silently dropped for now.
308            // Forwarding to upstream WS is a follow-up.
309            _ => {
310                send(
311                    &tx,
312                    &error_msg(
313                        &id,
314                        -32601,
315                        &format!("method '{method}' is not supported by the tidepool WS polyfill"),
316                    ),
317                );
318            }
319        }
320    }
321
322    // Client disconnected. Cancel every outstanding poll and drop the
323    // outbound channel (the write task exits when rx closes).
324    let mut subs = subs.lock().await;
325    for (_, handle) in subs.drain() {
326        handle.abort();
327    }
328    drop(tx);
329    let _ = write_task.await;
330}
331
332// ─── signature polling ──────────────────────────────────────────────
333
334async fn poll_signature(
335    sub_id: u64,
336    signature: String,
337    commitment: String,
338    state: WsState,
339    tx: mpsc::UnboundedSender<Message>,
340) {
341    let client = match reqwest::Client::builder()
342        .timeout(state.rpc_timeout)
343        .build()
344    {
345        Ok(c) => c,
346        Err(e) => {
347            warn!(err = %e, "failed to build reqwest client for ws polling");
348            return;
349        }
350    };
351    for _ in 0..MAX_POLLS {
352        tokio::time::sleep(POLL_INTERVAL).await;
353        let body = json!({
354            "jsonrpc": "2.0",
355            "id": 1,
356            "method": "getSignatureStatuses",
357            "params": [[signature], { "searchTransactionHistory": true }]
358        });
359        let Ok(resp) = client.post(&state.upstream_url).json(&body).send().await else {
360            continue;
361        };
362        let Ok(json): Result<Value, _> = resp.json().await else {
363            continue;
364        };
365        let Some(statuses) = json
366            .get("result")
367            .and_then(|r| r.get("value"))
368            .and_then(Value::as_array)
369        else {
370            continue;
371        };
372        let Some(status) = statuses.first() else {
373            continue;
374        };
375        if status.is_null() {
376            continue; // not yet seen
377        }
378        let status_conf = status
379            .get("confirmationStatus")
380            .and_then(Value::as_str)
381            .unwrap_or("");
382        if commitment_matches(&commitment, status_conf) {
383            // Emit notification and exit.
384            let notif = json!({
385                "jsonrpc": "2.0",
386                "method": "signatureNotification",
387                "params": {
388                    "result": {
389                        "context": json.get("result").and_then(|r| r.get("context")).cloned().unwrap_or(Value::Null),
390                        "value": { "err": status.get("err").cloned().unwrap_or(Value::Null) }
391                    },
392                    "subscription": sub_id
393                }
394            });
395            send(&tx, &notif);
396            return;
397        }
398    }
399    warn!(sub_id, signature, "signatureSubscribe poll timed out");
400}
401
402// ─── account polling ────────────────────────────────────────────────
403
404/// Long-lived account polling loop. Emits an `accountNotification`
405/// each time `getAccountInfo` returns a state that differs from the
406/// previously observed value. First poll emits the current state as
407/// the baseline; subsequent polls compare `{data, owner, lamports,
408/// executable, rentEpoch}` and only push on change.
409///
410/// Runs until the task is aborted (on `accountUnsubscribe` or client
411/// disconnect). Transient HTTP errors skip a cycle; we don't fail the
412/// subscription so clients stay connected across brief upstream
413/// flakes.
414async fn poll_account(
415    sub_id: u64,
416    pubkey: String,
417    commitment: String,
418    encoding: String,
419    state: WsState,
420    tx: mpsc::UnboundedSender<Message>,
421) {
422    let client = match reqwest::Client::builder()
423        .timeout(state.rpc_timeout)
424        .build()
425    {
426        Ok(c) => c,
427        Err(e) => {
428            warn!(err = %e, "failed to build reqwest client for account polling");
429            return;
430        }
431    };
432    let mut last: Option<Value> = None;
433    loop {
434        tokio::time::sleep(POLL_INTERVAL).await;
435        let body = json!({
436            "jsonrpc": "2.0",
437            "id": 1,
438            "method": "getAccountInfo",
439            "params": [pubkey, { "commitment": commitment, "encoding": encoding }]
440        });
441        let Ok(resp) = client.post(&state.upstream_url).json(&body).send().await else {
442            continue;
443        };
444        let Ok(json): Result<Value, _> = resp.json().await else {
445            continue;
446        };
447        let Some(result) = json.get("result") else {
448            continue;
449        };
450        // The `value` field is the account snapshot (may be Null when
451        // the account doesn't exist yet). `context` we include in the
452        // notification to match Helius / native Solana RPC shape.
453        let value = result.get("value").cloned().unwrap_or(Value::Null);
454        if last.as_ref() == Some(&value) {
455            continue;
456        }
457        last = Some(value.clone());
458        let notif = json!({
459            "jsonrpc": "2.0",
460            "method": "accountNotification",
461            "params": {
462                "result": {
463                    "context": result.get("context").cloned().unwrap_or(Value::Null),
464                    "value": value
465                },
466                "subscription": sub_id
467            }
468        });
469        send(&tx, &notif);
470    }
471}
472
473// ─── logs polling (mentions filter) ─────────────────────────────────
474
475/// Poll `getSignaturesForAddress(mention)` at `POLL_INTERVAL` and fan
476/// out to `getTransaction` for each new sig. Emits one
477/// `logsNotification` per new tx with the `logMessages` array extracted
478/// from meta. Runs until aborted.
479///
480/// Dedup strategy: remember the last-seen signature and page fresh
481/// results ahead of it. The first poll sets the baseline without
482/// emitting — matches Solana's "only notify on state change after
483/// subscribe" semantics for the other subscriptions.
484async fn poll_logs(
485    sub_id: u64,
486    mention: String,
487    commitment: String,
488    state: WsState,
489    tx: mpsc::UnboundedSender<Message>,
490) {
491    let client = match reqwest::Client::builder()
492        .timeout(state.rpc_timeout)
493        .build()
494    {
495        Ok(c) => c,
496        Err(e) => {
497            warn!(err = %e, "failed to build reqwest client for logs polling");
498            return;
499        }
500    };
501    let mut last_seen: Option<String> = None;
502    loop {
503        tokio::time::sleep(POLL_INTERVAL).await;
504        let sigs_body = json!({
505            "jsonrpc": "2.0",
506            "id": 1,
507            "method": "getSignaturesForAddress",
508            "params": [mention, { "commitment": commitment, "limit": 25 }]
509        });
510        let Ok(resp) = client
511            .post(&state.upstream_url)
512            .json(&sigs_body)
513            .send()
514            .await
515        else {
516            continue;
517        };
518        let Ok(json): Result<Value, _> = resp.json().await else {
519            continue;
520        };
521        let Some(entries) = json.get("result").and_then(Value::as_array) else {
522            continue;
523        };
524
525        // Collect new sigs in chronological order (upstream returns
526        // newest-first, so iterate reversed, stopping at the last-seen
527        // boundary).
528        let mut new_sigs: Vec<String> = Vec::new();
529        for entry in entries.iter().rev() {
530            let Some(sig) = entry.get("signature").and_then(Value::as_str) else {
531                continue;
532            };
533            if last_seen.as_deref() == Some(sig) {
534                new_sigs.clear();
535                continue;
536            }
537            new_sigs.push(sig.to_string());
538        }
539        // If we had no baseline, just set it and skip emitting — the
540        // subscription model emits on *future* events, not on the
541        // already-landed tx list.
542        if last_seen.is_none() {
543            if let Some(sig) = entries
544                .first()
545                .and_then(|e| e.get("signature"))
546                .and_then(Value::as_str)
547            {
548                last_seen = Some(sig.to_string());
549            }
550            continue;
551        }
552
553        for sig in &new_sigs {
554            if let Some(notif) =
555                fetch_logs_notification(&client, &state, &commitment, sub_id, sig).await
556            {
557                send(&tx, &notif);
558            }
559        }
560        if let Some(last) = new_sigs.last() {
561            last_seen = Some(last.clone());
562        }
563    }
564}
565
566/// One-shot `getTransaction` → `logsNotification` payload build. None
567/// on transient upstream error — caller continues polling.
568async fn fetch_logs_notification(
569    client: &reqwest::Client,
570    state: &WsState,
571    commitment: &str,
572    sub_id: u64,
573    signature: &str,
574) -> Option<Value> {
575    let body = json!({
576        "jsonrpc": "2.0",
577        "id": 1,
578        "method": "getTransaction",
579        "params": [
580            signature,
581            { "commitment": commitment, "encoding": "json", "maxSupportedTransactionVersion": 0 }
582        ]
583    });
584    let resp = client
585        .post(&state.upstream_url)
586        .json(&body)
587        .send()
588        .await
589        .ok()?;
590    let json: Value = resp.json().await.ok()?;
591    let result = json.get("result")?;
592    let slot = result.get("slot").and_then(Value::as_u64).unwrap_or(0);
593    let meta = result.get("meta").cloned().unwrap_or(Value::Null);
594    let err = meta.get("err").cloned().unwrap_or(Value::Null);
595    let logs = meta
596        .get("logMessages")
597        .cloned()
598        .unwrap_or(Value::Array(Vec::new()));
599    Some(json!({
600        "jsonrpc": "2.0",
601        "method": "logsNotification",
602        "params": {
603            "result": {
604                "context": { "slot": slot },
605                "value": {
606                    "signature": signature,
607                    "err": err,
608                    "logs": logs
609                }
610            },
611            "subscription": sub_id
612        }
613    }))
614}
615
/// Whether a transaction observed at confirmation level `actual` satisfies
/// a subscriber that asked for `requested`.
///
/// Solana's commitment ladder is processed < confirmed < finalized, so a
/// higher observed level satisfies a lower request (e.g. `finalized`
/// satisfies `confirmed`).
///
/// The deprecated `@solana/web3.js` aliases map onto the level they stand
/// for: `recent` → processed, `single` / `singleGossip` → confirmed,
/// `root` / `max` → finalized. Without this mapping, `max` would be
/// satisfied by a merely processed transaction.
///
/// Unrecognised strings (including an empty `actual`) rank below
/// `processed`.
fn commitment_matches(requested: &str, actual: &str) -> bool {
    fn rank(level: &str) -> u8 {
        match level {
            "processed" | "recent" => 1,
            "confirmed" | "single" | "singleGossip" => 2,
            "finalized" | "root" | "max" => 3,
            _ => 0,
        }
    }
    rank(actual) >= rank(requested)
}
628
629// ─── helpers ────────────────────────────────────────────────────────
630
631fn send(tx: &mpsc::UnboundedSender<Message>, value: &Value) {
632    let _ = tx.send(Message::Text(value.to_string().into()));
633}
634
635fn error_msg(id: &Value, code: i32, message: &str) -> Value {
636    json!({
637        "jsonrpc": "2.0",
638        "id": id,
639        "error": { "code": code, "message": message }
640    })
641}