1use std::{collections::BTreeMap, time::Duration};
2
3use futures_util::{
4 future::{ready, BoxFuture, FutureExt},
5 sink::SinkExt,
6 stream::{self, BoxStream, StreamExt},
7};
8use log::*;
9use serde::de::DeserializeOwned;
10use serde_json::{json, Value};
11use solana_account_decoder_client_types::UiAccount;
12use solana_rpc_client_api::{
13 config::{
14 RpcAccountInfoConfig, RpcBlockSubscribeConfig, RpcBlockSubscribeFilter,
15 RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
16 RpcTransactionLogsFilter,
17 },
18 error_object::RpcErrorObject,
19 response::{
20 Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
21 RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
22 },
23};
24use solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature};
25use thiserror::Error;
26use tokio::{
27 sync::{
28 mpsc::{self, UnboundedSender},
29 oneshot,
30 },
31 task::JoinHandle,
32};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34use tokio_tungstenite::{
35 connect_async,
36 tungstenite::{
37 protocol::frame::{coding::CloseCode, CloseFrame},
38 Message,
39 },
40};
41use url::Url;
42
43pub type PubsubClientResult<T = ()> = Result<T, PubsubClientError>;
44
45#[derive(Debug, Error)]
46pub enum PubsubClientError {
47 #[error("url parse error")]
48 UrlParseError(#[from] url::ParseError),
49
50 #[error("unable to connect to server")]
51 ConnectionError(tokio_tungstenite::tungstenite::Error),
52
53 #[error("websocket error")]
54 WsError(#[from] tokio_tungstenite::tungstenite::Error),
55
56 #[error("connection closed (({0})")]
57 ConnectionClosed(String),
58
59 #[error("json parse error")]
60 JsonParseError(#[from] serde_json::error::Error),
61
62 #[error("subscribe failed: {reason}")]
63 SubscribeFailed { reason: String, message: String },
64
65 #[error("unexpected message format: {0}")]
66 UnexpectedMessageError(String),
67
68 #[error("request failed: {reason}")]
69 RequestFailed { reason: String, message: String },
70
71 #[error("request error: {0}")]
72 RequestError(String),
73
74 #[error("could not find subscription id: {0}")]
75 UnexpectedSubscriptionResponse(String),
76
77 #[error("could not find node version: {0}")]
78 UnexpectedGetVersionResponse(String),
79}
80
81type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
82type SubscribeResponseMsg =
83 Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
84type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>);
85type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>;
86type RequestMsg = (
87 String,
88 Value,
89 oneshot::Sender<Result<Value, PubsubClientError>>,
90);
91
92#[derive(Clone)]
93struct SubscriptionInfo {
94 sender: UnboundedSender<Value>,
95 payload: String,
96}
97
98#[derive(Debug)]
102pub struct PubsubClient {
103 subscribe_sender: mpsc::UnboundedSender<SubscribeRequestMsg>,
104 _request_sender: mpsc::UnboundedSender<RequestMsg>,
105 shutdown_sender: oneshot::Sender<()>,
106 ws: JoinHandle<Result<(), PubsubClientError>>,
107 url: Url,
108}
109
110impl PubsubClient {
111 pub async fn new(url: &str) -> PubsubClientResult<Self> {
112 let url = Url::parse(url)?;
113
114 let (subscribe_sender, subscribe_receiver) = mpsc::unbounded_channel();
115 let (_request_sender, request_receiver) = mpsc::unbounded_channel();
116 let (shutdown_sender, shutdown_receiver) = oneshot::channel();
117
118 let ws_handle = tokio::spawn(PubsubClient::run_ws(
120 url.clone(),
121 subscribe_receiver,
122 request_receiver,
123 shutdown_receiver,
124 ));
125
126 #[allow(clippy::used_underscore_binding)]
127 Ok(Self {
128 subscribe_sender,
129 _request_sender,
130 shutdown_sender,
131 ws: ws_handle,
132 url,
133 })
134 }
135
136 pub fn url(&self) -> Url {
138 self.url.clone()
139 }
140
141 pub fn is_running(&self) -> bool {
145 !self.ws.is_finished()
146 }
147
148 pub async fn shutdown(self) -> PubsubClientResult {
149 let _ = self.shutdown_sender.send(());
150 self.ws.await.unwrap() }
152
153 async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T>
154 where
155 T: DeserializeOwned + Send + 'a,
156 {
157 let (response_sender, response_receiver) = oneshot::channel();
158 self.subscribe_sender
159 .send((operation.to_string(), params.clone(), response_sender))
160 .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
161
162 let (notifications, unsubscribe) = response_receiver
163 .await
164 .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
165
166 Ok((
167 UnboundedReceiverStream::new(notifications)
168 .filter_map(|value| ready(serde_json::from_value::<T>(value).ok()))
169 .boxed(),
170 unsubscribe,
171 ))
172 }
173
174 pub async fn account_subscribe(
184 &self,
185 pubkey: &Pubkey,
186 config: Option<RpcAccountInfoConfig>,
187 ) -> SubscribeResult<'_, RpcResponse<UiAccount>> {
188 let params = json!([pubkey.to_string(), config]);
189 self.subscribe("account", params).await
190 }
191
192 pub async fn block_subscribe(
205 &self,
206 filter: RpcBlockSubscribeFilter,
207 config: Option<RpcBlockSubscribeConfig>,
208 ) -> SubscribeResult<'_, RpcResponse<RpcBlockUpdate>> {
209 self.subscribe("block", json!([filter, config])).await
210 }
211
212 pub async fn logs_subscribe(
222 &self,
223 filter: RpcTransactionLogsFilter,
224 config: RpcTransactionLogsConfig,
225 ) -> SubscribeResult<'_, RpcResponse<RpcLogsResponse>> {
226 self.subscribe("logs", json!([filter, config])).await
227 }
228
229 pub async fn program_subscribe(
240 &self,
241 pubkey: &Pubkey,
242 config: Option<RpcProgramAccountsConfig>,
243 ) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> {
244 let params = json!([pubkey.to_string(), config]);
245 self.subscribe("program", params).await
246 }
247
248 pub async fn vote_subscribe(&self) -> SubscribeResult<'_, RpcVote> {
262 self.subscribe("vote", json!([])).await
263 }
264
265 pub async fn root_subscribe(&self) -> SubscribeResult<'_, Slot> {
278 self.subscribe("root", json!([])).await
279 }
280
281 pub async fn signature_subscribe(
295 &self,
296 signature: &Signature,
297 config: Option<RpcSignatureSubscribeConfig>,
298 ) -> SubscribeResult<'_, RpcResponse<RpcSignatureResult>> {
299 let params = json!([signature.to_string(), config]);
300 self.subscribe("signature", params).await
301 }
302
303 pub async fn slot_subscribe(&self) -> SubscribeResult<'_, SlotInfo> {
313 self.subscribe("slot", json!([])).await
314 }
315
316 pub async fn slot_updates_subscribe(&self) -> SubscribeResult<'_, SlotUpdate> {
331 self.subscribe("slotsUpdates", json!([])).await
332 }
333
334 async fn run_ws(
335 url: Url,
336 mut subscribe_receiver: mpsc::UnboundedReceiver<SubscribeRequestMsg>,
337 mut request_receiver: mpsc::UnboundedReceiver<RequestMsg>,
338 mut shutdown_receiver: oneshot::Receiver<()>,
339 ) -> PubsubClientResult {
340 let max_retry_count = 3;
343 let mut retry_count = 0;
344
345 let mut request_id: u64 = 0;
347 let mut subscriptions = BTreeMap::<u64, SubscriptionInfo>::new();
348 let mut request_id_to_sid = BTreeMap::<u64, u64>::new();
349 let (unsubscribe_sender, mut unsubscribe_receiver) = mpsc::unbounded_channel();
350
351 'reconnect: loop {
352 log::debug!(target: "ws", "PubsubClient connecting: {:?}", url.as_str());
353 let mut ws = match connect_async(url.as_str()).await {
354 Ok((ws, response)) => {
355 if response.status().is_server_error() || response.status().is_client_error() {
356 log::warn!(target: "ws", "couldn't reconnect: {response:?}");
357 retry_count += 1;
358 let delay = 2_u64.pow(2 + retry_count);
359 info!(target: "ws", "PubsubClient trying reconnect after {delay}s, attempt: {retry_count}/{max_retry_count}");
360 tokio::time::sleep(Duration::from_secs(delay)).await;
361 continue 'reconnect;
362 }
363
364 retry_count = 0;
365 ws
366 }
367 Err(err) => {
368 log::warn!(target: "ws", "couldn't reconnect: {err:?}");
369 if retry_count >= max_retry_count {
370 log::error!(target: "ws", "reached max reconnect attempts: {err:?}");
371 panic!("PubsubCliwnt reached max reconnect attempts: {err:?}");
372 }
373 retry_count += 1;
374 let delay = 2_u64.pow(2 + retry_count);
375 info!(target: "ws", "PubsubClient trying reconnect after {delay}s, attempt: {retry_count}/{max_retry_count}");
376 tokio::time::sleep(Duration::from_secs(delay)).await;
377 continue 'reconnect;
378 }
379 };
380
381 let mut inflight_subscribes =
382 BTreeMap::<u64, (String, String, oneshot::Sender<SubscribeResponseMsg>)>::new();
383 let mut inflight_unsubscribes = BTreeMap::<u64, oneshot::Sender<()>>::new();
384 let mut inflight_requests = BTreeMap::<u64, oneshot::Sender<_>>::new();
385
386 if !subscriptions.is_empty() {
388 info!(target: "ws", "resubscribing: {:?}", subscriptions.values().map(|x| x.payload.clone()).collect::<Vec<String>>());
389 if let Err(err) = ws
390 .send_all(&mut stream::iter(
391 subscriptions
392 .values()
393 .cloned()
394 .map(|s| Ok(Message::text(s.payload))),
395 ))
396 .await
397 {
398 error!(target: "ws", "PubsubClient failed resubscribing: {err:?}");
399 continue 'reconnect;
400 }
401 }
402
403 let mut liveness_check = tokio::time::interval(Duration::from_secs(60));
404 let _ = liveness_check.tick().await;
405
406 let mut heartbeat = tokio::time::interval(Duration::from_secs(30));
407 let _ = heartbeat.tick().await;
408
409 'manager: loop {
410 tokio::select! {
411 biased;
412 _ = (&mut shutdown_receiver) => {
414 log::info!(target: "ws", "PubsubClient received shutdown");
415 let frame = CloseFrame { code: CloseCode::Normal, reason: "".into() };
416 let _ = ws.send(Message::Close(Some(frame))).await;
417 let _ = ws.flush().await;
418 break 'reconnect Ok(());
419 },
420 next_msg = ws.next() => {
422 liveness_check.reset();
423 let msg = match next_msg {
424 Some(Ok(msg)) => msg,
425 Some(Err(err)) => {
426 log::warn!(target: "ws", "PubsubClient disconnected: {err:?}");
427 break 'manager
428 }
429 None => {
430 log::debug!(target: "ws", "PubsubClient disconnected");
431 break 'manager
432 },
433 };
434 trace!("ws.next(): {:?}", &msg);
435
436 let text = match msg {
438 Message::Text(ref text) => text,
439 Message::Close(_frame) => break 'manager,
440 Message::Ping(_) | Message::Pong(_) | Message::Binary(_) | Message::Frame(_) => continue 'manager,
441 };
442
443 let params = gjson::get(text, "params");
446 if params.exists() {
447 let sid = params.get("subscription").u64();
448 let mut unsubscribe_required = false;
449
450 if let Some(sub) = subscriptions.get(&sid) {
451 let result = params.get("result");
452 if result.exists() && sub.sender.send(serde_json::from_str(result.json()).expect("valid json")).is_err() {
453 unsubscribe_required = true;
454 }
455 } else {
456 unsubscribe_required = true;
457 }
458
459 if unsubscribe_required {
460 let method = gjson::get(text, "method");
461 if let Some(operation) = method.str().strip_suffix("Notification") {
462 let (response_sender, _response_receiver) = oneshot::channel();
463 let _ = unsubscribe_sender.send((operation.to_string(), sid, response_sender));
464 }
465 }
466 continue 'manager;
468 }
469
470 let id = gjson::get(text, "id");
473 if id.exists() {
474 let err = gjson::get(text, "error");
475 let err = if err.exists() {
476 match serde_json::from_str::<RpcErrorObject>(err.json()) {
477 Ok(rpc_error_object) => {
478 Some(format!("{} ({})", rpc_error_object.message, rpc_error_object.code))
479 }
480 Err(e) => Some(format!(
481 "Failed to deserialize RPC error response: {} [{e}]", err.str(),
482 ))
483 }
484 } else {
485 None
486 };
487
488 let id = id.u64();
489 if let Some(response_sender) = inflight_requests.remove(&id) {
490 match err {
491 Some(reason) => {
492 let _ = response_sender.send(Err(PubsubClientError::RequestFailed { reason, message: text.to_string()}));
493 },
494 None => {
495 let json_result = gjson::get(text, "result");
496 let json_result_value = if json_result.exists() {
497 Ok(serde_json::from_str::<Value>(json_result.json()).unwrap())
498 } else {
499 Err(PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.to_string() })
500 };
501
502 if let Err(err) = response_sender.send(json_result_value) {
503 log::warn!(target: "ws", "Ws request failed: {err:?}");
504 break 'manager;
505 }
506 }
507 }
508 } else if let Some(response_sender) = inflight_unsubscribes.remove(&id) {
509 let _ = response_sender.send(()); } else if let Some((operation, payload, response_sender)) = inflight_subscribes.remove(&id) {
511 match err {
512 Some(reason) => {
513 let _ = response_sender.send(Err(PubsubClientError::SubscribeFailed { reason, message: text.to_string()}));
514 },
515 None => {
516 let sid = gjson::get(text, "result");
518 if !sid.exists() {
519 return Err(PubsubClientError::SubscribeFailed { reason: "invalid `result` field".into(), message: text.to_string() });
520 }
521 let sid = sid.u64();
522
523 let (notifications_sender, notifications_receiver) = mpsc::unbounded_channel();
525 let unsubscribe_sender = unsubscribe_sender.clone();
526 let unsubscribe = Box::new(move || async move {
527 let (response_sender, response_receiver) = oneshot::channel();
528 if unsubscribe_sender.send((operation, id, response_sender)).is_ok() {
530 let _ = response_receiver.await; }
532 }.boxed());
533
534 if response_sender.send(Ok((notifications_receiver, unsubscribe))).is_err() {
535 break 'manager;
536 }
537 log::debug!(target: "ws", "subscription added: {sid:?}");
538 request_id_to_sid.insert(id, sid);
539 subscriptions.insert(sid, SubscriptionInfo {
540 sender: notifications_sender,
541 payload,
542 });
543 }
544 }
545 } else if let Some(previous_sid) = request_id_to_sid.remove(&id) {
546 match err {
547 Some(reason) => {
548 log::error!(target: "ws", "resubscription failed: {:?}, {reason:?}", text);
549 panic!();
550 },
551 None => {
552 let sid = gjson::get(text, "result");
554 if !sid.exists() {
555 log::error!(target: "ws", "resubscription failed. invalid `result` field: {:?}", text);
556 panic!();
557 }
558 let new_sid = sid.u64();
559
560 info!(target: "ws", "resubscribed: {previous_sid:>} => {new_sid:?}");
561 request_id_to_sid.insert(id, new_sid);
562 let info = subscriptions.remove(&previous_sid).unwrap();
563 subscriptions.insert(
564 new_sid,
565 info
566 );
567 }
568 }
569 } else {
570 error!(target: "ws", "PubSubClient received unknown request id: {id}");
571 break 'manager;
572 }
573 continue 'manager;
574 }
575 }
576 subscribe = subscribe_receiver.recv() => {
578 let (operation, params, response_sender) = subscribe.expect("subscribe channel");
579 request_id += 1;
580 let method = format!("{operation}Subscribe");
581 let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
582 if let Err(ref err) = ws.send(Message::Text(text.clone().into())).await {
583 log::warn!(target: "ws", "sending subscribe failed, {text}, {err:?}");
584 break 'manager;
585 }
586 inflight_subscribes.insert(request_id, (operation, text, response_sender));
587 },
588 unsubscribe = unsubscribe_receiver.recv() => {
590 let (operation, id, response_sender) = unsubscribe.expect("unsub channel");
591 if let Some(sid) = request_id_to_sid.remove(&id) {
592 subscriptions.remove(&sid);
593 request_id += 1;
594 let method = format!("{operation}Unsubscribe");
595 let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":[sid]}).to_string();
596 if let Err(err) = ws.send(Message::Text(text.clone().into())).await {
597 log::warn!(target: "ws", "sending unsubscribe failed: {text}, {err:?}");
598 }
599 inflight_unsubscribes.insert(request_id, response_sender);
600 }
601 },
602 request = request_receiver.recv() => {
604 let (method, params, response_sender) = request.expect("request channel");
605 request_id += 1;
606 let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
607 if let Err(err) = ws.send(Message::Text(text.into())).await {
608 log::warn!(target: "ws", "sending request failed. {err:?}");
609 }
610 inflight_requests.insert(request_id, response_sender);
611 },
612 _ = heartbeat.tick() => {
613 ws.send(Message::Ping(Default::default())).await?;
614 },
615 _ = liveness_check.tick() => {
616 warn!(target: "ws", "PubsubClient timed out");
617 break 'manager;
618 }
619 }
620 }
621 log::debug!(target: "ws", "manager finished");
622 }
623 }
624}