iq_cometbft_rpc/client/transport/
websocket.rs

1//! WebSocket-based clients for accessing CometBFT RPC functionality.
2
3use alloc::{borrow::Cow, collections::BTreeMap as HashMap, fmt};
4use core::{
5    convert::{TryFrom, TryInto},
6    ops::Add,
7    str::FromStr,
8};
9
10use async_trait::async_trait;
11use async_tungstenite::{
12    tokio::ConnectStream,
13    tungstenite::{
14        protocol::{frame::coding::CloseCode, CloseFrame},
15        Message,
16    },
17    WebSocketStream,
18};
19use futures::{SinkExt, StreamExt};
20use serde::{Deserialize, Serialize};
21use tokio::time::{Duration, Instant};
22use tracing::{debug, error};
23
24use cometbft::{block::Height, Hash};
25use cometbft_config::net;
26
27use super::router::{SubscriptionId, SubscriptionIdRef};
28use crate::{
29    client::{
30        subscription::SubscriptionTx,
31        sync::{ChannelRx, ChannelTx},
32        transport::router::{PublishResult, SubscriptionRouter},
33        Client, CompatMode,
34    },
35    dialect::{v0_34, Dialect, LatestDialect},
36    endpoint::{self, subscribe, unsubscribe},
37    error::Error,
38    event::{self, Event},
39    prelude::*,
40    query::Query,
41    request::Wrapper,
42    response, Id, Order, Request, Response, Scheme, SimpleRequest, Subscription,
43    SubscriptionClient, Url,
44};
45
// Number of seconds of total silence from the server (no message of any kind)
// after which the WebSocket connection is considered to have timed out.
//
// Mirrors the server-side value in CometBFT:
// https://github.com/cometbft/cometbft/blob/309e29c245a01825fc9630103311fd04de99fa5e/rpc/jsonrpc/server/ws_handler.go#L27
const RECV_TIMEOUT_SECONDS: u64 = 30;

// The receive timeout expressed as a `Duration`.
const RECV_TIMEOUT: Duration = Duration::from_secs(RECV_TIMEOUT_SECONDS);

// Interval between PING messages sent to the WebSocket server: 90% of the
// receive timeout, i.e. 30 * 9 / 10 = 27 seconds (integer arithmetic).
//
// Mirrors the server-side ping period in CometBFT:
// https://github.com/cometbft/cometbft/blob/309e29c245a01825fc9630103311fd04de99fa5e/rpc/jsonrpc/server/ws_handler.go#L28
const PING_INTERVAL: Duration = Duration::from_secs(RECV_TIMEOUT_SECONDS * 9 / 10);
58
59/// Low-level WebSocket configuration
60pub use async_tungstenite::tungstenite::protocol::WebSocketConfig;
61
62/// CometBFT RPC client that provides access to all RPC functionality
63/// (including [`Event`] subscription) over a WebSocket connection.
64///
65/// The `WebSocketClient` itself is effectively just a handle to its driver
66/// The driver is the component of the client that actually interacts with the
67/// remote RPC over the WebSocket connection. The `WebSocketClient` can
68/// therefore be cloned into different asynchronous contexts, effectively
69/// allowing for asynchronous access to the driver.
70///
71/// It is the caller's responsibility to spawn an asynchronous task in which to
72/// execute the [`WebSocketClientDriver::run`] method. See the example below.
73///
74/// Dropping [`Subscription`]s will automatically terminate them (the
75/// `WebSocketClientDriver` detects a disconnected channel and removes the
76/// subscription from its internal routing table). When all subscriptions to a
77/// particular query have disconnected, the driver will automatically issue an
78/// unsubscribe request to the remote RPC endpoint.
79///
80/// ### Timeouts
81///
82/// The WebSocket client connection times out after 30 seconds if it does not
83/// receive anything at all from the server. This will automatically return
84/// errors to all active subscriptions and terminate them.
85///
86/// This is not configurable at present.
87///
88/// ### Keep-Alive
89///
90/// The WebSocket client implements a keep-alive mechanism whereby it sends a
91/// PING message to the server every 27 seconds, matching the PING cadence of
92/// the CometBFT server (see [this code][cometbft-websocket-ping] for
93/// details).
94///
95/// This is not configurable at present.
96///
97/// ## Examples
98///
99/// ```rust,ignore
100/// use cometbft::abci::Transaction;
101/// use cometbft_rpc::{WebSocketClient, SubscriptionClient, Client};
102/// use cometbft_rpc::query::EventType;
103/// use futures::StreamExt;
104///
105/// #[tokio::main]
106/// async fn main() {
107///     let (client, driver) = WebSocketClient::new("ws://127.0.0.1:26657/websocket")
108///         .await
109///         .unwrap();
110///     let driver_handle = tokio::spawn(async move { driver.run().await });
111///
112///     // Standard client functionality
113///     let tx = format!("some-key=some-value");
114///     client.broadcast_tx_async(Transaction::from(tx.into_bytes())).await.unwrap();
115///
116///     // Subscription functionality
117///     let mut subs = client.subscribe(EventType::NewBlock.into())
118///         .await
119///         .unwrap();
120///
121///     // Grab 5 NewBlock events
122///     let mut ev_count = 5_i32;
123///
124///     while let Some(res) = subs.next().await {
125///         let ev = res.unwrap();
126///         println!("Got event: {:?}", ev);
127///         ev_count -= 1;
128///         if ev_count < 0 {
129///             break;
130///         }
131///     }
132///
133///     // Signal to the driver to terminate.
134///     client.close().unwrap();
135///     // Await the driver's termination to ensure proper connection closure.
136///     let _ = driver_handle.await.unwrap();
137/// }
138/// ```
139///
140/// [cometbft-websocket-ping]: https://github.com/cometbft/cometbft/blob/309e29c245a01825fc9630103311fd04de99fa5e/rpc/jsonrpc/server/ws_handler.go#L28
141#[derive(Debug, Clone)]
142pub struct WebSocketClient {
143    inner: sealed::WebSocketClient,
144    compat: CompatMode,
145}
146
147/// The builder pattern constructor for [`WebSocketClient`].
148pub struct Builder {
149    url: WebSocketClientUrl,
150    compat: CompatMode,
151    transport_config: Option<WebSocketConfig>,
152}
153
154impl Builder {
155    /// Use the specified compatibility mode for the CometBFT RPC protocol.
156    ///
157    /// The default is the latest protocol version supported by this crate.
158    pub fn compat_mode(mut self, mode: CompatMode) -> Self {
159        self.compat = mode;
160        self
161    }
162
163    /// Use the specified low-level WebSocket configuration options.
164    pub fn config(mut self, config: WebSocketConfig) -> Self {
165        self.transport_config = Some(config);
166        self
167    }
168
169    /// Try to create a client with the options specified for this builder.
170    pub async fn build(self) -> Result<(WebSocketClient, WebSocketClientDriver), Error> {
171        let url = self.url.0;
172        let compat = self.compat;
173        let (inner, driver) = if url.is_secure() {
174            sealed::WebSocketClient::new_secure(url, compat, self.transport_config).await?
175        } else {
176            sealed::WebSocketClient::new_unsecure(url, compat, self.transport_config).await?
177        };
178
179        Ok((WebSocketClient { inner, compat }, driver))
180    }
181}
182
183impl WebSocketClient {
184    /// Construct a new WebSocket-based client connecting to the given
185    /// CometBFT node's RPC endpoint.
186    ///
187    /// Supports both `ws://` and `wss://` protocols.
188    pub async fn new<U>(url: U) -> Result<(Self, WebSocketClientDriver), Error>
189    where
190        U: TryInto<WebSocketClientUrl, Error = Error>,
191    {
192        let url = url.try_into()?;
193        Self::builder(url).build().await
194    }
195
196    /// Construct a new WebSocket-based client connecting to the given
197    /// CometBFT node's RPC endpoint.
198    ///
199    /// Supports both `ws://` and `wss://` protocols.
200    pub async fn new_with_config<U>(
201        url: U,
202        config: WebSocketConfig,
203    ) -> Result<(Self, WebSocketClientDriver), Error>
204    where
205        U: TryInto<WebSocketClientUrl, Error = Error>,
206    {
207        let url = url.try_into()?;
208        Self::builder(url).config(config).build().await
209    }
210
211    /// Initiate a builder for a WebSocket-based client connecting to the given
212    /// CometBFT node's RPC endpoint.
213    ///
214    /// Supports both `ws://` and `wss://` protocols.
215    pub fn builder(url: WebSocketClientUrl) -> Builder {
216        Builder {
217            url,
218            compat: Default::default(),
219            transport_config: Default::default(),
220        }
221    }
222
223    async fn perform_with_dialect<R, S>(&self, request: R, dialect: S) -> Result<R::Output, Error>
224    where
225        R: SimpleRequest<S>,
226        S: Dialect,
227    {
228        self.inner.perform(request, dialect).await
229    }
230}
231
232#[async_trait]
233impl Client for WebSocketClient {
234    async fn perform<R>(&self, request: R) -> Result<R::Output, Error>
235    where
236        R: SimpleRequest,
237    {
238        self.perform_with_dialect(request, LatestDialect).await
239    }
240
241    async fn block<H>(&self, height: H) -> Result<endpoint::block::Response, Error>
242    where
243        H: Into<Height> + Send,
244    {
245        perform_with_compat!(self, endpoint::block::Request::new(height.into()))
246    }
247
248    async fn block_by_hash(
249        &self,
250        hash: cometbft::Hash,
251    ) -> Result<endpoint::block_by_hash::Response, Error> {
252        perform_with_compat!(self, endpoint::block_by_hash::Request::new(hash))
253    }
254
255    async fn latest_block(&self) -> Result<endpoint::block::Response, Error> {
256        perform_with_compat!(self, endpoint::block::Request::default())
257    }
258
259    async fn block_results<H>(&self, height: H) -> Result<endpoint::block_results::Response, Error>
260    where
261        H: Into<Height> + Send,
262    {
263        perform_with_compat!(self, endpoint::block_results::Request::new(height.into()))
264    }
265
266    async fn latest_block_results(&self) -> Result<endpoint::block_results::Response, Error> {
267        perform_with_compat!(self, endpoint::block_results::Request::default())
268    }
269
270    async fn block_search(
271        &self,
272        query: Query,
273        page: u32,
274        per_page: u8,
275        order: Order,
276    ) -> Result<endpoint::block_search::Response, Error> {
277        perform_with_compat!(
278            self,
279            endpoint::block_search::Request::new(query, page, per_page, order)
280        )
281    }
282
283    async fn header<H>(&self, height: H) -> Result<endpoint::header::Response, Error>
284    where
285        H: Into<Height> + Send,
286    {
287        let height = height.into();
288        match self.compat {
289            CompatMode::V0_38 => self.perform(endpoint::header::Request::new(height)).await,
290            CompatMode::V0_37 => self.perform(endpoint::header::Request::new(height)).await,
291            CompatMode::V0_34 => {
292                // Back-fill with a request to /block endpoint and
293                // taking just the header from the response.
294                let resp = self
295                    .perform_with_dialect(endpoint::block::Request::new(height), v0_34::Dialect)
296                    .await?;
297                Ok(resp.into())
298            },
299        }
300    }
301
302    async fn header_by_hash(
303        &self,
304        hash: Hash,
305    ) -> Result<endpoint::header_by_hash::Response, Error> {
306        match self.compat {
307            CompatMode::V0_38 => {
308                self.perform(endpoint::header_by_hash::Request::new(hash))
309                    .await
310            },
311            CompatMode::V0_37 => {
312                self.perform(endpoint::header_by_hash::Request::new(hash))
313                    .await
314            },
315            CompatMode::V0_34 => {
316                // Back-fill with a request to /block_by_hash endpoint and
317                // taking just the header from the response.
318                let resp = self
319                    .perform_with_dialect(
320                        endpoint::block_by_hash::Request::new(hash),
321                        v0_34::Dialect,
322                    )
323                    .await?;
324                Ok(resp.into())
325            },
326        }
327    }
328
329    async fn tx(&self, hash: Hash, prove: bool) -> Result<endpoint::tx::Response, Error> {
330        perform_with_compat!(self, endpoint::tx::Request::new(hash, prove))
331    }
332
333    async fn tx_search(
334        &self,
335        query: Query,
336        prove: bool,
337        page: u32,
338        per_page: u8,
339        order: Order,
340    ) -> Result<endpoint::tx_search::Response, Error> {
341        perform_with_compat!(
342            self,
343            endpoint::tx_search::Request::new(query, prove, page, per_page, order)
344        )
345    }
346
347    async fn broadcast_tx_commit<T>(
348        &self,
349        tx: T,
350    ) -> Result<endpoint::broadcast::tx_commit::Response, Error>
351    where
352        T: Into<Vec<u8>> + Send,
353    {
354        perform_with_compat!(self, endpoint::broadcast::tx_commit::Request::new(tx))
355    }
356}
357
358#[async_trait]
359impl SubscriptionClient for WebSocketClient {
360    async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
361        self.inner.subscribe(query).await
362    }
363
364    async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
365        self.inner.unsubscribe(query).await
366    }
367
368    fn close(self) -> Result<(), Error> {
369        self.inner.close()
370    }
371}
372
373/// A URL limited to use with WebSocket clients.
374///
375/// Facilitates useful type conversions and inferences.
376#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
377#[serde(transparent)]
378pub struct WebSocketClientUrl(Url);
379
380impl TryFrom<Url> for WebSocketClientUrl {
381    type Error = Error;
382
383    fn try_from(value: Url) -> Result<Self, Error> {
384        match value.scheme() {
385            Scheme::WebSocket | Scheme::SecureWebSocket => Ok(Self(value)),
386            _ => Err(Error::invalid_params(format!(
387                "cannot use URL {value} with WebSocket clients"
388            ))),
389        }
390    }
391}
392
393impl FromStr for WebSocketClientUrl {
394    type Err = Error;
395
396    fn from_str(s: &str) -> Result<Self, Error> {
397        let url: Url = s.parse()?;
398        url.try_into()
399    }
400}
401
402impl fmt::Display for WebSocketClientUrl {
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        self.0.fmt(f)
405    }
406}
407
408impl TryFrom<&str> for WebSocketClientUrl {
409    type Error = Error;
410
411    fn try_from(value: &str) -> Result<Self, Error> {
412        value.parse()
413    }
414}
415
416impl TryFrom<net::Address> for WebSocketClientUrl {
417    type Error = Error;
418
419    fn try_from(value: net::Address) -> Result<Self, Error> {
420        match value {
421            net::Address::Tcp {
422                peer_id: _,
423                host,
424                port,
425            } => format!("ws://{host}:{port}/websocket").parse(),
426            net::Address::Unix { .. } => Err(Error::invalid_params(
427                "only TCP-based node addresses are supported".to_string(),
428            )),
429        }
430    }
431}
432
433impl From<WebSocketClientUrl> for Url {
434    fn from(url: WebSocketClientUrl) -> Self {
435        url.0
436    }
437}
438
439mod sealed {
440    use async_tungstenite::{
441        tokio::{connect_async_with_config, connect_async_with_tls_connector_and_config},
442        tungstenite::client::IntoClientRequest,
443    };
444    use tracing::debug;
445
446    use super::{
447        DriverCommand, SimpleRequestCommand, SubscribeCommand, UnsubscribeCommand,
448        WebSocketClientDriver, WebSocketConfig,
449    };
450    use crate::{
451        client::{
452            sync::{unbounded, ChannelTx},
453            transport::auth::authorize,
454            CompatMode,
455        },
456        dialect::Dialect,
457        prelude::*,
458        query::Query,
459        request::Wrapper,
460        utils::uuid_str,
461        Error, Response, SimpleRequest, Subscription, Url,
462    };
463
464    /// Marker for the [`AsyncTungsteniteClient`] for clients operating over
465    /// unsecure connections.
466    #[derive(Debug, Clone)]
467    pub struct Unsecure;
468
469    /// Marker for the [`AsyncTungsteniteClient`] for clients operating over
470    /// secure connections.
471    #[derive(Debug, Clone)]
472    pub struct Secure;
473
474    /// An [`async-tungstenite`]-based WebSocket client.
475    ///
476    /// Different modes of operation (secure and unsecure) are facilitated by
477    /// different variants of this type.
478    ///
479    /// [`async-tungstenite`]: https://crates.io/crates/async-tungstenite
480    #[derive(Debug, Clone)]
481    pub struct AsyncTungsteniteClient<C> {
482        cmd_tx: ChannelTx<DriverCommand>,
483        _client_type: core::marker::PhantomData<C>,
484    }
485
486    impl AsyncTungsteniteClient<Unsecure> {
487        /// Construct a WebSocket client. Immediately attempts to open a WebSocket
488        /// connection to the node with the given address.
489        ///
490        /// On success, this returns both a client handle (a `WebSocketClient`
491        /// instance) as well as the WebSocket connection driver. The execution of
492        /// this driver becomes the responsibility of the client owner, and must be
493        /// executed in a separate asynchronous context to the client to ensure it
494        /// doesn't block the client.
495        pub async fn new(
496            url: Url,
497            compat: CompatMode,
498            config: Option<WebSocketConfig>,
499        ) -> Result<(Self, WebSocketClientDriver), Error> {
500            debug!("Connecting to unsecure WebSocket endpoint: {}", url);
501
502            let (stream, _response) = connect_async_with_config(url, config)
503                .await
504                .map_err(Error::tungstenite)?;
505
506            let (cmd_tx, cmd_rx) = unbounded();
507            let driver = WebSocketClientDriver::new(stream, cmd_rx, compat);
508            let client = Self {
509                cmd_tx,
510                _client_type: Default::default(),
511            };
512
513            Ok((client, driver))
514        }
515    }
516
517    impl AsyncTungsteniteClient<Secure> {
518        /// Construct a WebSocket client. Immediately attempts to open a WebSocket
519        /// connection to the node with the given address, but over a secure
520        /// connection.
521        ///
522        /// On success, this returns both a client handle (a `WebSocketClient`
523        /// instance) as well as the WebSocket connection driver. The execution of
524        /// this driver becomes the responsibility of the client owner, and must be
525        /// executed in a separate asynchronous context to the client to ensure it
526        /// doesn't block the client.
527        pub async fn new(
528            url: Url,
529            compat: CompatMode,
530            config: Option<WebSocketConfig>,
531        ) -> Result<(Self, WebSocketClientDriver), Error> {
532            debug!("Connecting to secure WebSocket endpoint: {}", url);
533
534            // Not supplying a connector means async_tungstenite will create the
535            // connector for us.
536            let (stream, _response) =
537                connect_async_with_tls_connector_and_config(url, None, config)
538                    .await
539                    .map_err(Error::tungstenite)?;
540
541            let (cmd_tx, cmd_rx) = unbounded();
542            let driver = WebSocketClientDriver::new(stream, cmd_rx, compat);
543            let client = Self {
544                cmd_tx,
545                _client_type: Default::default(),
546            };
547
548            Ok((client, driver))
549        }
550    }
551
552    impl<C> AsyncTungsteniteClient<C> {
553        fn send_cmd(&self, cmd: DriverCommand) -> Result<(), Error> {
554            self.cmd_tx.send(cmd)
555        }
556
557        /// Signals to the driver that it must terminate.
558        pub fn close(self) -> Result<(), Error> {
559            self.send_cmd(DriverCommand::Terminate)
560        }
561    }
562
563    impl<C> AsyncTungsteniteClient<C> {
564        pub async fn perform<R, S>(&self, request: R) -> Result<R::Output, Error>
565        where
566            R: SimpleRequest<S>,
567            S: Dialect,
568        {
569            let wrapper = Wrapper::new(request);
570            let id = wrapper.id().to_string();
571            let wrapped_request = wrapper.into_json();
572
573            tracing::debug!("Outgoing request: {}", wrapped_request);
574
575            let (response_tx, mut response_rx) = unbounded();
576
577            self.send_cmd(DriverCommand::SimpleRequest(SimpleRequestCommand {
578                id,
579                wrapped_request,
580                response_tx,
581            }))?;
582
583            let response = response_rx.recv().await.ok_or_else(|| {
584                Error::client_internal("failed to hear back from WebSocket driver".to_string())
585            })??;
586
587            tracing::debug!("Incoming response: {}", response);
588
589            R::Response::from_string(response).map(Into::into)
590        }
591
592        pub async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
593            let (subscription_tx, subscription_rx) = unbounded();
594            let (response_tx, mut response_rx) = unbounded();
595            // By default we use UUIDs to differentiate subscriptions
596            let id = uuid_str();
597            self.send_cmd(DriverCommand::Subscribe(SubscribeCommand {
598                id: id.to_string(),
599                query: query.to_string(),
600                subscription_tx,
601                response_tx,
602            }))?;
603            // Make sure our subscription request went through successfully.
604            response_rx.recv().await.ok_or_else(|| {
605                Error::client_internal("failed to hear back from WebSocket driver".to_string())
606            })??;
607            Ok(Subscription::new(id, query, subscription_rx))
608        }
609
610        pub async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
611            let (response_tx, mut response_rx) = unbounded();
612            self.send_cmd(DriverCommand::Unsubscribe(UnsubscribeCommand {
613                query: query.to_string(),
614                response_tx,
615            }))?;
616            response_rx.recv().await.ok_or_else(|| {
617                Error::client_internal("failed to hear back from WebSocket driver".to_string())
618            })??;
619            Ok(())
620        }
621    }
622
623    /// Allows us to erase the type signatures associated with the different
624    /// WebSocket client variants.
625    #[derive(Debug, Clone)]
626    pub enum WebSocketClient {
627        Unsecure(AsyncTungsteniteClient<Unsecure>),
628        Secure(AsyncTungsteniteClient<Secure>),
629    }
630
631    impl WebSocketClient {
632        pub async fn new_unsecure(
633            url: Url,
634            compat: CompatMode,
635            config: Option<WebSocketConfig>,
636        ) -> Result<(Self, WebSocketClientDriver), Error> {
637            let (client, driver) =
638                AsyncTungsteniteClient::<Unsecure>::new(url, compat, config).await?;
639            Ok((Self::Unsecure(client), driver))
640        }
641
642        pub async fn new_secure(
643            url: Url,
644            compat: CompatMode,
645            config: Option<WebSocketConfig>,
646        ) -> Result<(Self, WebSocketClientDriver), Error> {
647            let (client, driver) =
648                AsyncTungsteniteClient::<Secure>::new(url, compat, config).await?;
649            Ok((Self::Secure(client), driver))
650        }
651
652        pub fn close(self) -> Result<(), Error> {
653            match self {
654                WebSocketClient::Unsecure(c) => c.close(),
655                WebSocketClient::Secure(c) => c.close(),
656            }
657        }
658    }
659
660    impl WebSocketClient {
661        pub async fn perform<R, S>(&self, request: R, _dialect: S) -> Result<R::Output, Error>
662        where
663            R: SimpleRequest<S>,
664            S: Dialect,
665        {
666            match self {
667                WebSocketClient::Unsecure(c) => c.perform(request).await,
668                WebSocketClient::Secure(c) => c.perform(request).await,
669            }
670        }
671
672        pub async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
673            match self {
674                WebSocketClient::Unsecure(c) => c.subscribe(query).await,
675                WebSocketClient::Secure(c) => c.subscribe(query).await,
676            }
677        }
678
679        pub async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
680            match self {
681                WebSocketClient::Unsecure(c) => c.unsubscribe(query).await,
682                WebSocketClient::Secure(c) => c.unsubscribe(query).await,
683            }
684        }
685    }
686
687    use async_tungstenite::tungstenite;
688
689    impl IntoClientRequest for Url {
690        fn into_client_request(
691            self,
692        ) -> tungstenite::Result<tungstenite::handshake::client::Request> {
693            let builder = tungstenite::handshake::client::Request::builder()
694                .method("GET")
695                .header("Host", self.host())
696                .header("Connection", "Upgrade")
697                .header("Upgrade", "websocket")
698                .header("Sec-WebSocket-Version", "13")
699                .header(
700                    "Sec-WebSocket-Key",
701                    tungstenite::handshake::client::generate_key(),
702                );
703
704            let builder = if let Some(auth) = authorize(self.as_ref()) {
705                builder.header("Authorization", auth.to_string())
706            } else {
707                builder
708            };
709
710            builder
711                .uri(self.to_string())
712                .body(())
713                .map_err(tungstenite::error::Error::HttpFormat)
714        }
715    }
716}
717
718// The different types of commands that can be sent from the WebSocketClient to
719// the driver.
720#[derive(Debug, Clone)]
721enum DriverCommand {
722    // Initiate a subscription request.
723    Subscribe(SubscribeCommand),
724    // Initiate an unsubscribe request.
725    Unsubscribe(UnsubscribeCommand),
726    // For non-subscription-related requests.
727    SimpleRequest(SimpleRequestCommand),
728    Terminate,
729}
730
731#[derive(Debug, Clone)]
732struct SubscribeCommand {
733    // The desired ID for the outgoing JSON-RPC request.
734    id: String,
735    // The query for which we want to receive events.
736    query: String,
737    // Where to send subscription events.
738    subscription_tx: SubscriptionTx,
739    // Where to send the result of the subscription request.
740    response_tx: ChannelTx<Result<(), Error>>,
741}
742
743#[derive(Debug, Clone)]
744struct UnsubscribeCommand {
745    // The query from which to unsubscribe.
746    query: String,
747    // Where to send the result of the unsubscribe request.
748    response_tx: ChannelTx<Result<(), Error>>,
749}
750
751#[derive(Debug, Clone)]
752struct SimpleRequestCommand {
753    // The desired ID for the outgoing JSON-RPC request. Technically we
754    // could extract this from the wrapped request, but that would mean
755    // additional unnecessary computational resources for deserialization.
756    id: String,
757    // The wrapped and serialized JSON-RPC request.
758    wrapped_request: String,
759    // Where to send the result of the simple request.
760    response_tx: ChannelTx<Result<String, Error>>,
761}
762
763#[derive(Serialize, Deserialize, Debug, Clone)]
764struct GenericJsonResponse(serde_json::Value);
765
766impl Response for GenericJsonResponse {}
767
768/// Drives the WebSocket connection for a `WebSocketClient` instance.
769///
770/// This is the primary component responsible for transport-level interaction
771/// with the remote WebSocket endpoint.
772pub struct WebSocketClientDriver {
773    // The underlying WebSocket network connection.
774    stream: WebSocketStream<ConnectStream>,
775    // Facilitates routing of events to their respective subscriptions.
776    router: SubscriptionRouter,
777    // How we receive incoming commands from the WebSocketClient.
778    cmd_rx: ChannelRx<DriverCommand>,
779    // Commands we've received but have not yet completed, indexed by their ID.
780    // A Terminate command is executed immediately.
781    pending_commands: HashMap<SubscriptionId, DriverCommand>,
782    // The compatibility mode directing how to parse subscription events.
783    compat: CompatMode,
784}
785
786impl WebSocketClientDriver {
787    fn new(
788        stream: WebSocketStream<ConnectStream>,
789        cmd_rx: ChannelRx<DriverCommand>,
790        compat: CompatMode,
791    ) -> Self {
792        Self {
793            stream,
794            router: SubscriptionRouter::default(),
795            cmd_rx,
796            pending_commands: HashMap::new(),
797            compat,
798        }
799    }
800
801    async fn send_msg(&mut self, msg: Message) -> Result<(), Error> {
802        self.stream.send(msg).await.map_err(|e| {
803            Error::web_socket("failed to write to WebSocket connection".to_string(), e)
804        })
805    }
806
807    async fn simple_request(&mut self, cmd: SimpleRequestCommand) -> Result<(), Error> {
808        if let Err(e) = self
809            .send_msg(Message::Text(cmd.wrapped_request.clone()))
810            .await
811        {
812            cmd.response_tx.send(Err(e.clone()))?;
813            return Err(e);
814        }
815        self.pending_commands
816            .insert(cmd.id.clone(), DriverCommand::SimpleRequest(cmd));
817        Ok(())
818    }
819
820    /// Executes the WebSocket driver, which manages the underlying WebSocket
821    /// transport.
822    pub async fn run(mut self) -> Result<(), Error> {
823        let mut ping_interval =
824            tokio::time::interval_at(Instant::now().add(PING_INTERVAL), PING_INTERVAL);
825
826        let recv_timeout = tokio::time::sleep(RECV_TIMEOUT);
827        tokio::pin!(recv_timeout);
828
829        loop {
830            tokio::select! {
831                Some(res) = self.stream.next() => match res {
832                    Ok(msg) => {
833                        // Reset the receive timeout every time we successfully
834                        // receive a message from the remote endpoint.
835                        recv_timeout.as_mut().reset(Instant::now().add(RECV_TIMEOUT));
836                        self.handle_incoming_msg(msg).await?
837                    },
838                    Err(e) => return Err(
839                        Error::web_socket(
840                            "failed to read from WebSocket connection".to_string(),
841                            e
842                        ),
843                    ),
844                },
845                Some(cmd) = self.cmd_rx.recv() => match cmd {
846                    DriverCommand::Subscribe(subs_cmd) => self.subscribe(subs_cmd).await?,
847                    DriverCommand::Unsubscribe(unsubs_cmd) => self.unsubscribe(unsubs_cmd).await?,
848                    DriverCommand::SimpleRequest(req_cmd) => self.simple_request(req_cmd).await?,
849                    DriverCommand::Terminate => return self.close().await,
850                },
851                _ = ping_interval.tick() => self.ping().await?,
852                _ = &mut recv_timeout => {
853                    return Err(Error::web_socket_timeout(RECV_TIMEOUT));
854                }
855            }
856        }
857    }
858
859    async fn send_request<R>(&mut self, wrapper: Wrapper<R>) -> Result<(), Error>
860    where
861        R: Request,
862    {
863        self.send_msg(Message::Text(
864            serde_json::to_string_pretty(&wrapper).unwrap(),
865        ))
866        .await
867    }
868
869    async fn subscribe(&mut self, cmd: SubscribeCommand) -> Result<(), Error> {
870        // If we already have an active subscription for the given query,
871        // there's no need to initiate another one. Just add this subscription
872        // to the router.
873        if self.router.num_subscriptions_for_query(cmd.query.clone()) > 0 {
874            let (id, query, subscription_tx, response_tx) =
875                (cmd.id, cmd.query, cmd.subscription_tx, cmd.response_tx);
876            self.router.add(id, query, subscription_tx);
877            return response_tx.send(Ok(()));
878        }
879
880        // Otherwise, we need to initiate a subscription request.
881        let wrapper = Wrapper::new_with_id(
882            Id::Str(cmd.id.clone()),
883            subscribe::Request::new(cmd.query.clone()),
884        );
885        if let Err(e) = self.send_request(wrapper).await {
886            cmd.response_tx.send(Err(e.clone()))?;
887            return Err(e);
888        }
889        self.pending_commands
890            .insert(cmd.id.clone(), DriverCommand::Subscribe(cmd));
891        Ok(())
892    }
893
894    async fn unsubscribe(&mut self, cmd: UnsubscribeCommand) -> Result<(), Error> {
895        // Terminate all subscriptions for this query immediately. This
896        // prioritizes acknowledgement of the caller's wishes over networking
897        // problems.
898        if self.router.remove_by_query(cmd.query.clone()) == 0 {
899            // If there were no subscriptions for this query, respond
900            // immediately.
901            cmd.response_tx.send(Ok(()))?;
902            return Ok(());
903        }
904
905        // Unsubscribe requests can (and probably should) have distinct
906        // JSON-RPC IDs as compared to their subscription IDs.
907        let wrapper = Wrapper::new(unsubscribe::Request::new(cmd.query.clone()));
908        let req_id = wrapper.id().clone();
909        if let Err(e) = self.send_request(wrapper).await {
910            cmd.response_tx.send(Err(e.clone()))?;
911            return Err(e);
912        }
913        self.pending_commands
914            .insert(req_id.to_string(), DriverCommand::Unsubscribe(cmd));
915        Ok(())
916    }
917
918    async fn handle_incoming_msg(&mut self, msg: Message) -> Result<(), Error> {
919        match msg {
920            Message::Text(s) => self.handle_text_msg(s).await,
921            Message::Ping(v) => self.pong(v).await,
922            _ => Ok(()),
923        }
924    }
925
926    async fn handle_text_msg(&mut self, msg: String) -> Result<(), Error> {
927        let parse_res = match self.compat {
928            CompatMode::V0_38 => event::v0_38::DeEvent::from_string(&msg).map(Into::into),
929            CompatMode::V0_37 => event::v1::DeEvent::from_string(&msg).map(Into::into),
930            CompatMode::V0_34 => event::v0_34::DeEvent::from_string(&msg).map(Into::into),
931        };
932        if let Ok(ev) = parse_res {
933            debug!("JSON-RPC event: {}", msg);
934            self.publish_event(ev).await;
935            return Ok(());
936        }
937
938        let wrapper: response::Wrapper<GenericJsonResponse> = match serde_json::from_str(&msg) {
939            Ok(w) => w,
940            Err(e) => {
941                error!(
942                    "Failed to deserialize incoming message as a JSON-RPC message: {}",
943                    e
944                );
945
946                debug!("JSON-RPC message: {}", msg);
947
948                return Ok(());
949            },
950        };
951
952        debug!("Generic JSON-RPC message: {:?}", wrapper);
953
954        let id = wrapper.id().to_string();
955
956        if let Some(e) = wrapper.into_error() {
957            self.publish_error(&id, e).await;
958        }
959
960        if let Some(pending_cmd) = self.pending_commands.remove(&id) {
961            self.respond_to_pending_command(pending_cmd, msg).await?;
962        };
963
964        // We ignore incoming messages whose ID we don't recognize (could be
965        // relating to a fire-and-forget unsubscribe request - see the
966        // publish_event() method below).
967        Ok(())
968    }
969
970    async fn publish_error(&mut self, id: SubscriptionIdRef<'_>, err: Error) {
971        if let PublishResult::AllDisconnected(query) = self.router.publish_error(id, err) {
972            debug!(
973                "All subscribers for query \"{}\" have disconnected. Unsubscribing from query...",
974                query
975            );
976
977            // If all subscribers have disconnected for this query, we need to
978            // unsubscribe from it. We issue a fire-and-forget unsubscribe
979            // message.
980            if let Err(e) = self
981                .send_request(Wrapper::new(unsubscribe::Request::new(query)))
982                .await
983            {
984                error!("Failed to send unsubscribe request: {}", e);
985            }
986        }
987    }
988
989    async fn publish_event(&mut self, ev: Event) {
990        if let PublishResult::AllDisconnected(query) = self.router.publish_event(ev) {
991            debug!(
992                "All subscribers for query \"{}\" have disconnected. Unsubscribing from query...",
993                query
994            );
995
996            // If all subscribers have disconnected for this query, we need to
997            // unsubscribe from it. We issue a fire-and-forget unsubscribe
998            // message.
999            if let Err(e) = self
1000                .send_request(Wrapper::new(unsubscribe::Request::new(query)))
1001                .await
1002            {
1003                error!("Failed to send unsubscribe request: {}", e);
1004            }
1005        }
1006    }
1007
1008    async fn respond_to_pending_command(
1009        &mut self,
1010        pending_cmd: DriverCommand,
1011        response: String,
1012    ) -> Result<(), Error> {
1013        match pending_cmd {
1014            DriverCommand::Subscribe(cmd) => {
1015                let (id, query, subscription_tx, response_tx) =
1016                    (cmd.id, cmd.query, cmd.subscription_tx, cmd.response_tx);
1017                self.router.add(id, query, subscription_tx);
1018                response_tx.send(Ok(()))
1019            },
1020            DriverCommand::Unsubscribe(cmd) => cmd.response_tx.send(Ok(())),
1021            DriverCommand::SimpleRequest(cmd) => cmd.response_tx.send(Ok(response)),
1022            _ => Ok(()),
1023        }
1024    }
1025
1026    async fn pong(&mut self, v: Vec<u8>) -> Result<(), Error> {
1027        self.send_msg(Message::Pong(v)).await
1028    }
1029
1030    async fn ping(&mut self) -> Result<(), Error> {
1031        self.send_msg(Message::Ping(Vec::new())).await
1032    }
1033
1034    async fn close(mut self) -> Result<(), Error> {
1035        self.send_msg(Message::Close(Some(CloseFrame {
1036            code: CloseCode::Normal,
1037            reason: Cow::from("client closed WebSocket connection"),
1038        })))
1039        .await?;
1040
1041        while let Some(res) = self.stream.next().await {
1042            if res.is_err() {
1043                return Ok(());
1044            }
1045        }
1046        Ok(())
1047    }
1048}