1use alloc::{borrow::Cow, collections::BTreeMap as HashMap, fmt};
4use core::{
5 convert::{TryFrom, TryInto},
6 ops::Add,
7 str::FromStr,
8};
9
10use async_trait::async_trait;
11use async_tungstenite::{
12 tokio::ConnectStream,
13 tungstenite::{
14 protocol::{frame::coding::CloseCode, CloseFrame},
15 Message,
16 },
17 WebSocketStream,
18};
19use futures::{SinkExt, StreamExt};
20use serde::{Deserialize, Serialize};
21use tokio::time::{Duration, Instant};
22use tracing::{debug, error};
23
24use cometbft::{block::Height, Hash};
25use cometbft_config::net;
26
27use super::router::{SubscriptionId, SubscriptionIdRef};
28use crate::{
29 client::{
30 subscription::SubscriptionTx,
31 sync::{ChannelRx, ChannelTx},
32 transport::router::{PublishResult, SubscriptionRouter},
33 Client, CompatMode,
34 },
35 dialect::{v0_34, Dialect, LatestDialect},
36 endpoint::{self, subscribe, unsubscribe},
37 error::Error,
38 event::{self, Event},
39 prelude::*,
40 query::Query,
41 request::Wrapper,
42 response, Id, Order, Request, Response, Scheme, SimpleRequest, Subscription,
43 SubscriptionClient, Url,
44};
45
/// Seconds the driver tolerates without receiving *any* frame before it gives
/// up on the connection.
const RECV_TIMEOUT_SECONDS: u64 = 30;

/// [`RECV_TIMEOUT_SECONDS`] as a [`Duration`]; the driver resets this timer
/// every time a frame is successfully read.
const RECV_TIMEOUT: Duration = Duration::from_secs(RECV_TIMEOUT_SECONDS);

/// Interval between client-initiated pings: nine tenths of the receive
/// timeout, so that a ping goes out before the receive timeout would expire.
const PING_INTERVAL: Duration = Duration::from_secs(RECV_TIMEOUT_SECONDS * 9 / 10);
58
59pub use async_tungstenite::tungstenite::protocol::WebSocketConfig;
61
62#[derive(Debug, Clone)]
142pub struct WebSocketClient {
143 inner: sealed::WebSocketClient,
144 compat: CompatMode,
145}
146
147pub struct Builder {
149 url: WebSocketClientUrl,
150 compat: CompatMode,
151 transport_config: Option<WebSocketConfig>,
152}
153
154impl Builder {
155 pub fn compat_mode(mut self, mode: CompatMode) -> Self {
159 self.compat = mode;
160 self
161 }
162
163 pub fn config(mut self, config: WebSocketConfig) -> Self {
165 self.transport_config = Some(config);
166 self
167 }
168
169 pub async fn build(self) -> Result<(WebSocketClient, WebSocketClientDriver), Error> {
171 let url = self.url.0;
172 let compat = self.compat;
173 let (inner, driver) = if url.is_secure() {
174 sealed::WebSocketClient::new_secure(url, compat, self.transport_config).await?
175 } else {
176 sealed::WebSocketClient::new_unsecure(url, compat, self.transport_config).await?
177 };
178
179 Ok((WebSocketClient { inner, compat }, driver))
180 }
181}
182
183impl WebSocketClient {
184 pub async fn new<U>(url: U) -> Result<(Self, WebSocketClientDriver), Error>
189 where
190 U: TryInto<WebSocketClientUrl, Error = Error>,
191 {
192 let url = url.try_into()?;
193 Self::builder(url).build().await
194 }
195
196 pub async fn new_with_config<U>(
201 url: U,
202 config: WebSocketConfig,
203 ) -> Result<(Self, WebSocketClientDriver), Error>
204 where
205 U: TryInto<WebSocketClientUrl, Error = Error>,
206 {
207 let url = url.try_into()?;
208 Self::builder(url).config(config).build().await
209 }
210
211 pub fn builder(url: WebSocketClientUrl) -> Builder {
216 Builder {
217 url,
218 compat: Default::default(),
219 transport_config: Default::default(),
220 }
221 }
222
223 async fn perform_with_dialect<R, S>(&self, request: R, dialect: S) -> Result<R::Output, Error>
224 where
225 R: SimpleRequest<S>,
226 S: Dialect,
227 {
228 self.inner.perform(request, dialect).await
229 }
230}
231
232#[async_trait]
233impl Client for WebSocketClient {
234 async fn perform<R>(&self, request: R) -> Result<R::Output, Error>
235 where
236 R: SimpleRequest,
237 {
238 self.perform_with_dialect(request, LatestDialect).await
239 }
240
241 async fn block_results<H>(&self, height: H) -> Result<endpoint::block_results::Response, Error>
242 where
243 H: Into<Height> + Send,
244 {
245 perform_with_compat!(self, endpoint::block_results::Request::new(height.into()))
246 }
247
248 async fn latest_block_results(&self) -> Result<endpoint::block_results::Response, Error> {
249 perform_with_compat!(self, endpoint::block_results::Request::default())
250 }
251
252 async fn header<H>(&self, height: H) -> Result<endpoint::header::Response, Error>
253 where
254 H: Into<Height> + Send,
255 {
256 let height = height.into();
257 match self.compat {
258 CompatMode::V0_37 => self.perform(endpoint::header::Request::new(height)).await,
259 CompatMode::V0_34 => {
260 let resp = self
263 .perform_with_dialect(endpoint::block::Request::new(height), v0_34::Dialect)
264 .await?;
265 Ok(resp.into())
266 },
267 }
268 }
269
270 async fn header_by_hash(
271 &self,
272 hash: Hash,
273 ) -> Result<endpoint::header_by_hash::Response, Error> {
274 match self.compat {
275 CompatMode::V0_37 => {
276 self.perform(endpoint::header_by_hash::Request::new(hash))
277 .await
278 },
279 CompatMode::V0_34 => {
280 let resp = self
283 .perform_with_dialect(
284 endpoint::block_by_hash::Request::new(hash),
285 v0_34::Dialect,
286 )
287 .await?;
288 Ok(resp.into())
289 },
290 }
291 }
292
293 async fn tx(&self, hash: Hash, prove: bool) -> Result<endpoint::tx::Response, Error> {
294 perform_with_compat!(self, endpoint::tx::Request::new(hash, prove))
295 }
296
297 async fn tx_search(
298 &self,
299 query: Query,
300 prove: bool,
301 page: u32,
302 per_page: u8,
303 order: Order,
304 ) -> Result<endpoint::tx_search::Response, Error> {
305 perform_with_compat!(
306 self,
307 endpoint::tx_search::Request::new(query, prove, page, per_page, order)
308 )
309 }
310
311 async fn broadcast_tx_commit<T>(
312 &self,
313 tx: T,
314 ) -> Result<endpoint::broadcast::tx_commit::Response, Error>
315 where
316 T: Into<Vec<u8>> + Send,
317 {
318 perform_with_compat!(self, endpoint::broadcast::tx_commit::Request::new(tx))
319 }
320}
321
322#[async_trait]
323impl SubscriptionClient for WebSocketClient {
324 async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
325 self.inner.subscribe(query).await
326 }
327
328 async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
329 self.inner.unsubscribe(query).await
330 }
331
332 fn close(self) -> Result<(), Error> {
333 self.inner.close()
334 }
335}
336
337#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
341#[serde(transparent)]
342pub struct WebSocketClientUrl(Url);
343
344impl TryFrom<Url> for WebSocketClientUrl {
345 type Error = Error;
346
347 fn try_from(value: Url) -> Result<Self, Error> {
348 match value.scheme() {
349 Scheme::WebSocket | Scheme::SecureWebSocket => Ok(Self(value)),
350 _ => Err(Error::invalid_params(format!(
351 "cannot use URL {value} with WebSocket clients"
352 ))),
353 }
354 }
355}
356
357impl FromStr for WebSocketClientUrl {
358 type Err = Error;
359
360 fn from_str(s: &str) -> Result<Self, Error> {
361 let url: Url = s.parse()?;
362 url.try_into()
363 }
364}
365
366impl fmt::Display for WebSocketClientUrl {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368 self.0.fmt(f)
369 }
370}
371
372impl TryFrom<&str> for WebSocketClientUrl {
373 type Error = Error;
374
375 fn try_from(value: &str) -> Result<Self, Error> {
376 value.parse()
377 }
378}
379
380impl TryFrom<net::Address> for WebSocketClientUrl {
381 type Error = Error;
382
383 fn try_from(value: net::Address) -> Result<Self, Error> {
384 match value {
385 net::Address::Tcp {
386 peer_id: _,
387 host,
388 port,
389 } => format!("ws://{host}:{port}/websocket").parse(),
390 net::Address::Unix { .. } => Err(Error::invalid_params(
391 "only TCP-based node addresses are supported".to_string(),
392 )),
393 }
394 }
395}
396
397impl From<WebSocketClientUrl> for Url {
398 fn from(url: WebSocketClientUrl) -> Self {
399 url.0
400 }
401}
402
403mod sealed {
404 use async_tungstenite::{
405 tokio::{connect_async_with_config, connect_async_with_tls_connector_and_config},
406 tungstenite::client::IntoClientRequest,
407 };
408 use tracing::debug;
409
410 use super::{
411 DriverCommand, SimpleRequestCommand, SubscribeCommand, UnsubscribeCommand,
412 WebSocketClientDriver, WebSocketConfig,
413 };
414 use crate::{
415 client::{
416 sync::{unbounded, ChannelTx},
417 transport::auth::authorize,
418 CompatMode,
419 },
420 dialect::Dialect,
421 prelude::*,
422 query::Query,
423 request::Wrapper,
424 utils::uuid_str,
425 Error, Response, SimpleRequest, Subscription, Url,
426 };
427
428 #[derive(Debug, Clone)]
431 pub struct Unsecure;
432
433 #[derive(Debug, Clone)]
436 pub struct Secure;
437
438 #[derive(Debug, Clone)]
445 pub struct AsyncTungsteniteClient<C> {
446 cmd_tx: ChannelTx<DriverCommand>,
447 _client_type: core::marker::PhantomData<C>,
448 }
449
450 impl AsyncTungsteniteClient<Unsecure> {
451 pub async fn new(
460 url: Url,
461 compat: CompatMode,
462 config: Option<WebSocketConfig>,
463 ) -> Result<(Self, WebSocketClientDriver), Error> {
464 debug!("Connecting to unsecure WebSocket endpoint: {}", url);
465
466 let (stream, _response) = connect_async_with_config(url, config)
467 .await
468 .map_err(Error::tungstenite)?;
469
470 let (cmd_tx, cmd_rx) = unbounded();
471 let driver = WebSocketClientDriver::new(stream, cmd_rx, compat);
472 let client = Self {
473 cmd_tx,
474 _client_type: Default::default(),
475 };
476
477 Ok((client, driver))
478 }
479 }
480
481 impl AsyncTungsteniteClient<Secure> {
482 pub async fn new(
492 url: Url,
493 compat: CompatMode,
494 config: Option<WebSocketConfig>,
495 ) -> Result<(Self, WebSocketClientDriver), Error> {
496 debug!("Connecting to secure WebSocket endpoint: {}", url);
497
498 let (stream, _response) =
501 connect_async_with_tls_connector_and_config(url, None, config)
502 .await
503 .map_err(Error::tungstenite)?;
504
505 let (cmd_tx, cmd_rx) = unbounded();
506 let driver = WebSocketClientDriver::new(stream, cmd_rx, compat);
507 let client = Self {
508 cmd_tx,
509 _client_type: Default::default(),
510 };
511
512 Ok((client, driver))
513 }
514 }
515
516 impl<C> AsyncTungsteniteClient<C> {
517 fn send_cmd(&self, cmd: DriverCommand) -> Result<(), Error> {
518 self.cmd_tx.send(cmd)
519 }
520
521 pub fn close(self) -> Result<(), Error> {
523 self.send_cmd(DriverCommand::Terminate)
524 }
525 }
526
527 impl<C> AsyncTungsteniteClient<C> {
528 pub async fn perform<R, S>(&self, request: R) -> Result<R::Output, Error>
529 where
530 R: SimpleRequest<S>,
531 S: Dialect,
532 {
533 let wrapper = Wrapper::new(request);
534 let id = wrapper.id().to_string();
535 let wrapped_request = wrapper.into_json();
536
537 tracing::debug!("Outgoing request: {}", wrapped_request);
538
539 let (response_tx, mut response_rx) = unbounded();
540
541 self.send_cmd(DriverCommand::SimpleRequest(SimpleRequestCommand {
542 id,
543 wrapped_request,
544 response_tx,
545 }))?;
546
547 let response = response_rx.recv().await.ok_or_else(|| {
548 Error::client_internal("failed to hear back from WebSocket driver".to_string())
549 })??;
550
551 tracing::debug!("Incoming response: {}", response);
552
553 R::Response::from_string(response).map(Into::into)
554 }
555
556 pub async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
557 let (subscription_tx, subscription_rx) = unbounded();
558 let (response_tx, mut response_rx) = unbounded();
559 let id = uuid_str();
561 self.send_cmd(DriverCommand::Subscribe(SubscribeCommand {
562 id: id.to_string(),
563 query: query.to_string(),
564 subscription_tx,
565 response_tx,
566 }))?;
567 response_rx.recv().await.ok_or_else(|| {
569 Error::client_internal("failed to hear back from WebSocket driver".to_string())
570 })??;
571 Ok(Subscription::new(id, query, subscription_rx))
572 }
573
574 pub async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
575 let (response_tx, mut response_rx) = unbounded();
576 self.send_cmd(DriverCommand::Unsubscribe(UnsubscribeCommand {
577 query: query.to_string(),
578 response_tx,
579 }))?;
580 response_rx.recv().await.ok_or_else(|| {
581 Error::client_internal("failed to hear back from WebSocket driver".to_string())
582 })??;
583 Ok(())
584 }
585 }
586
587 #[derive(Debug, Clone)]
590 pub enum WebSocketClient {
591 Unsecure(AsyncTungsteniteClient<Unsecure>),
592 Secure(AsyncTungsteniteClient<Secure>),
593 }
594
595 impl WebSocketClient {
596 pub async fn new_unsecure(
597 url: Url,
598 compat: CompatMode,
599 config: Option<WebSocketConfig>,
600 ) -> Result<(Self, WebSocketClientDriver), Error> {
601 let (client, driver) =
602 AsyncTungsteniteClient::<Unsecure>::new(url, compat, config).await?;
603 Ok((Self::Unsecure(client), driver))
604 }
605
606 pub async fn new_secure(
607 url: Url,
608 compat: CompatMode,
609 config: Option<WebSocketConfig>,
610 ) -> Result<(Self, WebSocketClientDriver), Error> {
611 let (client, driver) =
612 AsyncTungsteniteClient::<Secure>::new(url, compat, config).await?;
613 Ok((Self::Secure(client), driver))
614 }
615
616 pub fn close(self) -> Result<(), Error> {
617 match self {
618 WebSocketClient::Unsecure(c) => c.close(),
619 WebSocketClient::Secure(c) => c.close(),
620 }
621 }
622 }
623
624 impl WebSocketClient {
625 pub async fn perform<R, S>(&self, request: R, _dialect: S) -> Result<R::Output, Error>
626 where
627 R: SimpleRequest<S>,
628 S: Dialect,
629 {
630 match self {
631 WebSocketClient::Unsecure(c) => c.perform(request).await,
632 WebSocketClient::Secure(c) => c.perform(request).await,
633 }
634 }
635
636 pub async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
637 match self {
638 WebSocketClient::Unsecure(c) => c.subscribe(query).await,
639 WebSocketClient::Secure(c) => c.subscribe(query).await,
640 }
641 }
642
643 pub async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
644 match self {
645 WebSocketClient::Unsecure(c) => c.unsubscribe(query).await,
646 WebSocketClient::Secure(c) => c.unsubscribe(query).await,
647 }
648 }
649 }
650
651 use async_tungstenite::tungstenite;
652
653 impl IntoClientRequest for Url {
654 fn into_client_request(
655 self,
656 ) -> tungstenite::Result<tungstenite::handshake::client::Request> {
657 let builder = tungstenite::handshake::client::Request::builder()
658 .method("GET")
659 .header("Host", self.host())
660 .header("Connection", "Upgrade")
661 .header("Upgrade", "websocket")
662 .header("Sec-WebSocket-Version", "13")
663 .header(
664 "Sec-WebSocket-Key",
665 tungstenite::handshake::client::generate_key(),
666 );
667
668 let builder = if let Some(auth) = authorize(self.as_ref()) {
669 builder.header("Authorization", auth.to_string())
670 } else {
671 builder
672 };
673
674 builder
675 .uri(self.to_string())
676 .body(())
677 .map_err(tungstenite::error::Error::HttpFormat)
678 }
679 }
680}
681
682#[derive(Debug, Clone)]
685enum DriverCommand {
686 Subscribe(SubscribeCommand),
688 Unsubscribe(UnsubscribeCommand),
690 SimpleRequest(SimpleRequestCommand),
692 Terminate,
693}
694
695#[derive(Debug, Clone)]
696struct SubscribeCommand {
697 id: String,
699 query: String,
701 subscription_tx: SubscriptionTx,
703 response_tx: ChannelTx<Result<(), Error>>,
705}
706
707#[derive(Debug, Clone)]
708struct UnsubscribeCommand {
709 query: String,
711 response_tx: ChannelTx<Result<(), Error>>,
713}
714
715#[derive(Debug, Clone)]
716struct SimpleRequestCommand {
717 id: String,
721 wrapped_request: String,
723 response_tx: ChannelTx<Result<String, Error>>,
725}
726
727#[derive(Serialize, Deserialize, Debug, Clone)]
728struct GenericJsonResponse(serde_json::Value);
729
730impl Response for GenericJsonResponse {}
731
732pub struct WebSocketClientDriver {
737 stream: WebSocketStream<ConnectStream>,
739 router: SubscriptionRouter,
741 cmd_rx: ChannelRx<DriverCommand>,
743 pending_commands: HashMap<SubscriptionId, DriverCommand>,
746 compat: CompatMode,
748}
749
750impl WebSocketClientDriver {
751 fn new(
752 stream: WebSocketStream<ConnectStream>,
753 cmd_rx: ChannelRx<DriverCommand>,
754 compat: CompatMode,
755 ) -> Self {
756 Self {
757 stream,
758 router: SubscriptionRouter::default(),
759 cmd_rx,
760 pending_commands: HashMap::new(),
761 compat,
762 }
763 }
764
765 async fn send_msg(&mut self, msg: Message) -> Result<(), Error> {
766 self.stream.send(msg).await.map_err(|e| {
767 Error::web_socket("failed to write to WebSocket connection".to_string(), e)
768 })
769 }
770
771 async fn simple_request(&mut self, cmd: SimpleRequestCommand) -> Result<(), Error> {
772 if let Err(e) = self
773 .send_msg(Message::Text(cmd.wrapped_request.clone()))
774 .await
775 {
776 cmd.response_tx.send(Err(e.clone()))?;
777 return Err(e);
778 }
779 self.pending_commands
780 .insert(cmd.id.clone(), DriverCommand::SimpleRequest(cmd));
781 Ok(())
782 }
783
784 pub async fn run(mut self) -> Result<(), Error> {
787 let mut ping_interval =
788 tokio::time::interval_at(Instant::now().add(PING_INTERVAL), PING_INTERVAL);
789
790 let recv_timeout = tokio::time::sleep(RECV_TIMEOUT);
791 tokio::pin!(recv_timeout);
792
793 loop {
794 tokio::select! {
795 Some(res) = self.stream.next() => match res {
796 Ok(msg) => {
797 recv_timeout.as_mut().reset(Instant::now().add(RECV_TIMEOUT));
800 self.handle_incoming_msg(msg).await?
801 },
802 Err(e) => return Err(
803 Error::web_socket(
804 "failed to read from WebSocket connection".to_string(),
805 e
806 ),
807 ),
808 },
809 Some(cmd) = self.cmd_rx.recv() => match cmd {
810 DriverCommand::Subscribe(subs_cmd) => self.subscribe(subs_cmd).await?,
811 DriverCommand::Unsubscribe(unsubs_cmd) => self.unsubscribe(unsubs_cmd).await?,
812 DriverCommand::SimpleRequest(req_cmd) => self.simple_request(req_cmd).await?,
813 DriverCommand::Terminate => return self.close().await,
814 },
815 _ = ping_interval.tick() => self.ping().await?,
816 _ = &mut recv_timeout => {
817 return Err(Error::web_socket_timeout(RECV_TIMEOUT));
818 }
819 }
820 }
821 }
822
823 async fn send_request<R>(&mut self, wrapper: Wrapper<R>) -> Result<(), Error>
824 where
825 R: Request,
826 {
827 self.send_msg(Message::Text(
828 serde_json::to_string_pretty(&wrapper).unwrap(),
829 ))
830 .await
831 }
832
833 async fn subscribe(&mut self, cmd: SubscribeCommand) -> Result<(), Error> {
834 if self.router.num_subscriptions_for_query(cmd.query.clone()) > 0 {
838 let (id, query, subscription_tx, response_tx) =
839 (cmd.id, cmd.query, cmd.subscription_tx, cmd.response_tx);
840 self.router.add(id, query, subscription_tx);
841 return response_tx.send(Ok(()));
842 }
843
844 let wrapper = Wrapper::new_with_id(
846 Id::Str(cmd.id.clone()),
847 subscribe::Request::new(cmd.query.clone()),
848 );
849 if let Err(e) = self.send_request(wrapper).await {
850 cmd.response_tx.send(Err(e.clone()))?;
851 return Err(e);
852 }
853 self.pending_commands
854 .insert(cmd.id.clone(), DriverCommand::Subscribe(cmd));
855 Ok(())
856 }
857
858 async fn unsubscribe(&mut self, cmd: UnsubscribeCommand) -> Result<(), Error> {
859 if self.router.remove_by_query(cmd.query.clone()) == 0 {
863 cmd.response_tx.send(Ok(()))?;
866 return Ok(());
867 }
868
869 let wrapper = Wrapper::new(unsubscribe::Request::new(cmd.query.clone()));
872 let req_id = wrapper.id().clone();
873 if let Err(e) = self.send_request(wrapper).await {
874 cmd.response_tx.send(Err(e.clone()))?;
875 return Err(e);
876 }
877 self.pending_commands
878 .insert(req_id.to_string(), DriverCommand::Unsubscribe(cmd));
879 Ok(())
880 }
881
882 async fn handle_incoming_msg(&mut self, msg: Message) -> Result<(), Error> {
883 match msg {
884 Message::Text(s) => self.handle_text_msg(s).await,
885 Message::Ping(v) => self.pong(v).await,
886 _ => Ok(()),
887 }
888 }
889
890 async fn handle_text_msg(&mut self, msg: String) -> Result<(), Error> {
891 let parse_res = match self.compat {
892 CompatMode::V0_37 => event::v1::DeEvent::from_string(&msg).map(Into::into),
893 CompatMode::V0_34 => event::v0_34::DeEvent::from_string(&msg).map(Into::into),
894 };
895 if let Ok(ev) = parse_res {
896 debug!("JSON-RPC event: {}", msg);
897 self.publish_event(ev).await;
898 return Ok(());
899 }
900
901 let wrapper: response::Wrapper<GenericJsonResponse> = match serde_json::from_str(&msg) {
902 Ok(w) => w,
903 Err(e) => {
904 error!(
905 "Failed to deserialize incoming message as a JSON-RPC message: {}",
906 e
907 );
908
909 debug!("JSON-RPC message: {}", msg);
910
911 return Ok(());
912 },
913 };
914
915 debug!("Generic JSON-RPC message: {:?}", wrapper);
916
917 let id = wrapper.id().to_string();
918
919 if let Some(e) = wrapper.into_error() {
920 self.publish_error(&id, e).await;
921 }
922
923 if let Some(pending_cmd) = self.pending_commands.remove(&id) {
924 self.respond_to_pending_command(pending_cmd, msg).await?;
925 };
926
927 Ok(())
931 }
932
933 async fn publish_error(&mut self, id: SubscriptionIdRef<'_>, err: Error) {
934 if let PublishResult::AllDisconnected(query) = self.router.publish_error(id, err) {
935 debug!(
936 "All subscribers for query \"{}\" have disconnected. Unsubscribing from query...",
937 query
938 );
939
940 if let Err(e) = self
944 .send_request(Wrapper::new(unsubscribe::Request::new(query)))
945 .await
946 {
947 error!("Failed to send unsubscribe request: {}", e);
948 }
949 }
950 }
951
952 async fn publish_event(&mut self, ev: Event) {
953 if let PublishResult::AllDisconnected(query) = self.router.publish_event(ev) {
954 debug!(
955 "All subscribers for query \"{}\" have disconnected. Unsubscribing from query...",
956 query
957 );
958
959 if let Err(e) = self
963 .send_request(Wrapper::new(unsubscribe::Request::new(query)))
964 .await
965 {
966 error!("Failed to send unsubscribe request: {}", e);
967 }
968 }
969 }
970
971 async fn respond_to_pending_command(
972 &mut self,
973 pending_cmd: DriverCommand,
974 response: String,
975 ) -> Result<(), Error> {
976 match pending_cmd {
977 DriverCommand::Subscribe(cmd) => {
978 let (id, query, subscription_tx, response_tx) =
979 (cmd.id, cmd.query, cmd.subscription_tx, cmd.response_tx);
980 self.router.add(id, query, subscription_tx);
981 response_tx.send(Ok(()))
982 },
983 DriverCommand::Unsubscribe(cmd) => cmd.response_tx.send(Ok(())),
984 DriverCommand::SimpleRequest(cmd) => cmd.response_tx.send(Ok(response)),
985 _ => Ok(()),
986 }
987 }
988
989 async fn pong(&mut self, v: Vec<u8>) -> Result<(), Error> {
990 self.send_msg(Message::Pong(v)).await
991 }
992
993 async fn ping(&mut self) -> Result<(), Error> {
994 self.send_msg(Message::Ping(Vec::new())).await
995 }
996
997 async fn close(mut self) -> Result<(), Error> {
998 self.send_msg(Message::Close(Some(CloseFrame {
999 code: CloseCode::Normal,
1000 reason: Cow::from("client closed WebSocket connection"),
1001 })))
1002 .await?;
1003
1004 while let Some(res) = self.stream.next().await {
1005 if res.is_err() {
1006 return Ok(());
1007 }
1008 }
1009 Ok(())
1010 }
1011}
1012
1013#[cfg(test)]
1014mod tests {
1015 use alloc::collections::BTreeMap as HashMap;
1016 use core::str::FromStr;
1017 use std::{path::PathBuf, println};
1018
1019 use async_tungstenite::{
1020 tokio::{accept_async, TokioAdapter},
1021 tungstenite::client::IntoClientRequest,
1022 };
1023 use cometbft_config::net;
1024 use futures::StreamExt;
1025 use http::{header::AUTHORIZATION, Uri};
1026 use tokio::{
1027 fs,
1028 net::{TcpListener, TcpStream},
1029 task::JoinHandle,
1030 };
1031
1032 use super::*;
1033 use crate::{client::sync::unbounded, event, query::EventType, request, Id, Method};
1034
1035 struct TestServer {
1037 node_addr: net::Address,
1038 driver_hdl: JoinHandle<Result<(), Error>>,
1039 terminate_tx: ChannelTx<Result<(), Error>>,
1040 event_tx: ChannelTx<Event>,
1041 }
1042
1043 #[derive(Copy, Clone)]
1046 enum TestRpcVersion {
1047 V0_34,
1048 V0_37,
1049 V0_38,
1050 }
1051
1052 impl TestServer {
1053 async fn new(addr: &str, version: TestRpcVersion) -> Self {
1054 let listener = TcpListener::bind(addr).await.unwrap();
1055 let local_addr = listener.local_addr().unwrap();
1056 let node_addr = net::Address::Tcp {
1057 peer_id: None,
1058 host: local_addr.ip().to_string(),
1059 port: local_addr.port(),
1060 };
1061 let (terminate_tx, terminate_rx) = unbounded();
1062 let (event_tx, event_rx) = unbounded();
1063 let driver = TestServerDriver::new(listener, version, event_rx, terminate_rx);
1064 let driver_hdl = tokio::spawn(async move { driver.run().await });
1065 Self {
1066 node_addr,
1067 driver_hdl,
1068 terminate_tx,
1069 event_tx,
1070 }
1071 }
1072
1073 fn publish_event(&mut self, ev: Event) -> Result<(), Error> {
1074 self.event_tx.send(ev)
1075 }
1076
1077 async fn terminate(self) -> Result<(), Error> {
1078 self.terminate_tx.send(Ok(())).unwrap();
1079 self.driver_hdl.await.unwrap()
1080 }
1081 }
1082
1083 struct TestServerDriver {
1085 listener: TcpListener,
1086 version: TestRpcVersion,
1087 event_rx: ChannelRx<Event>,
1088 terminate_rx: ChannelRx<Result<(), Error>>,
1089 handlers: Vec<TestServerHandler>,
1090 }
1091
1092 impl TestServerDriver {
1093 fn new(
1094 listener: TcpListener,
1095 version: TestRpcVersion,
1096 event_rx: ChannelRx<Event>,
1097 terminate_rx: ChannelRx<Result<(), Error>>,
1098 ) -> Self {
1099 Self {
1100 listener,
1101 version,
1102 event_rx,
1103 terminate_rx,
1104 handlers: Vec::new(),
1105 }
1106 }
1107
1108 async fn run(mut self) -> Result<(), Error> {
1109 loop {
1110 tokio::select! {
1111 Some(ev) = self.event_rx.recv() => self.publish_event(ev),
1112 res = self.listener.accept() => {
1113 let (stream, _) = res.unwrap();
1114 self.handle_incoming(stream).await
1115 }
1116 Some(res) = self.terminate_rx.recv() => {
1117 self.terminate().await;
1118 return res;
1119 },
1120 }
1121 }
1122 }
1123
1124 fn publish_event(&mut self, ev: Event) {
1127 for handler in &mut self.handlers {
1128 handler.publish_event(ev.clone());
1129 }
1130 }
1131
1132 async fn handle_incoming(&mut self, stream: TcpStream) {
1133 self.handlers
1134 .push(TestServerHandler::new(stream, self.version).await);
1135 }
1136
1137 async fn terminate(&mut self) {
1138 while !self.handlers.is_empty() {
1139 let handler = match self.handlers.pop() {
1140 Some(h) => h,
1141 None => break,
1142 };
1143 let _ = handler.terminate().await;
1144 }
1145 }
1146 }
1147
1148 struct TestServerHandler {
1151 driver_hdl: JoinHandle<Result<(), Error>>,
1152 terminate_tx: ChannelTx<Result<(), Error>>,
1153 event_tx: ChannelTx<Event>,
1154 }
1155
1156 impl TestServerHandler {
1157 async fn new(stream: TcpStream, version: TestRpcVersion) -> Self {
1158 let conn: WebSocketStream<TokioAdapter<TcpStream>> =
1159 accept_async(stream).await.unwrap();
1160 let (terminate_tx, terminate_rx) = unbounded();
1161 let (event_tx, event_rx) = unbounded();
1162 let driver = TestServerHandlerDriver::new(conn, version, event_rx, terminate_rx);
1163 let driver_hdl = tokio::spawn(async move { driver.run().await });
1164 Self {
1165 driver_hdl,
1166 terminate_tx,
1167 event_tx,
1168 }
1169 }
1170
1171 fn publish_event(&mut self, ev: Event) {
1172 let _ = self.event_tx.send(ev);
1173 }
1174
1175 async fn terminate(self) -> Result<(), Error> {
1176 self.terminate_tx.send(Ok(()))?;
1177 self.driver_hdl.await.unwrap()
1178 }
1179 }
1180
1181 struct TestServerHandlerDriver {
1183 conn: WebSocketStream<TokioAdapter<TcpStream>>,
1184 version: TestRpcVersion,
1185 event_rx: ChannelRx<Event>,
1186 terminate_rx: ChannelRx<Result<(), Error>>,
1187 subscriptions: HashMap<String, String>,
1190 }
1191
1192 impl TestServerHandlerDriver {
1193 fn new(
1194 conn: WebSocketStream<TokioAdapter<TcpStream>>,
1195 version: TestRpcVersion,
1196 event_rx: ChannelRx<Event>,
1197 terminate_rx: ChannelRx<Result<(), Error>>,
1198 ) -> Self {
1199 Self {
1200 conn,
1201 version,
1202 event_rx,
1203 terminate_rx,
1204 subscriptions: HashMap::new(),
1205 }
1206 }
1207
1208 async fn run(mut self) -> Result<(), Error> {
1209 loop {
1210 tokio::select! {
1211 Some(msg) = self.conn.next() => {
1212 if let Some(ret) = self.handle_incoming_msg(msg.unwrap()).await {
1213 return ret;
1214 }
1215 }
1216 Some(ev) = self.event_rx.recv() => self.publish_event(ev).await,
1217 Some(res) = self.terminate_rx.recv() => {
1218 self.terminate().await;
1219 return res;
1220 },
1221 }
1222 }
1223 }
1224
1225 async fn publish_event(&mut self, ev: Event) {
1226 let subs_id = match self.subscriptions.get(&ev.query) {
1227 Some(id) => Id::Str(id.clone()),
1228 None => return,
1229 };
1230 match self.version {
1231 TestRpcVersion::V0_38 => {
1232 let ev: event::v0_38::SerEvent = ev.into();
1233 self.send(subs_id, ev).await;
1234 },
1235 TestRpcVersion::V0_37 => {
1236 let ev: event::v0_37::SerEvent = ev.into();
1237 self.send(subs_id, ev).await;
1238 },
1239 TestRpcVersion::V0_34 => {
1240 let ev: event::v0_34::SerEvent = ev.into();
1241 self.send(subs_id, ev).await;
1242 },
1243 }
1244 }
1245
1246 async fn handle_incoming_msg(&mut self, msg: Message) -> Option<Result<(), Error>> {
1247 match msg {
1248 Message::Text(s) => self.handle_incoming_text_msg(s).await,
1249 Message::Ping(v) => {
1250 let _ = self.conn.send(Message::Pong(v)).await;
1251 None
1252 },
1253 Message::Close(_) => {
1254 self.terminate().await;
1255 Some(Ok(()))
1256 },
1257 _ => None,
1258 }
1259 }
1260
1261 async fn handle_incoming_text_msg(&mut self, msg: String) -> Option<Result<(), Error>> {
1262 match serde_json::from_str::<serde_json::Value>(&msg) {
1263 Ok(json_msg) => {
1264 if let Some(json_method) = json_msg.get("method") {
1265 match Method::from_str(json_method.as_str().unwrap()) {
1266 Ok(method) => match method {
1267 Method::Subscribe => {
1268 let req = serde_json::from_str::<
1269 request::Wrapper<subscribe::Request>,
1270 >(&msg)
1271 .unwrap();
1272
1273 self.add_subscription(
1274 req.params().query.clone(),
1275 req.id().to_string(),
1276 );
1277 self.send(req.id().clone(), subscribe::Response {}).await;
1278 },
1279 Method::Unsubscribe => {
1280 let req = serde_json::from_str::<
1281 request::Wrapper<unsubscribe::Request>,
1282 >(&msg)
1283 .unwrap();
1284
1285 self.remove_subscription(req.params().query.clone());
1286 self.send(req.id().clone(), unsubscribe::Response {}).await;
1287 },
1288 _ => {
1289 println!("Unsupported method in incoming request: {}", &method);
1290 },
1291 },
1292 Err(e) => {
1293 println!(
1294 "Unexpected method in incoming request: {json_method} ({e})"
1295 );
1296 },
1297 }
1298 }
1299 },
1300 Err(e) => {
1301 println!("Failed to parse incoming request: {} ({})", &msg, e);
1302 },
1303 }
1304 None
1305 }
1306
1307 fn add_subscription(&mut self, query: String, id: String) {
1308 println!("Adding subscription with ID {} for query: {}", &id, &query);
1309 self.subscriptions.insert(query, id);
1310 }
1311
1312 fn remove_subscription(&mut self, query: String) {
1313 if let Some(id) = self.subscriptions.remove(&query) {
1314 println!("Removed subscription {id} for query: {query}");
1315 }
1316 }
1317
1318 async fn send<R>(&mut self, id: Id, res: R)
1319 where
1320 R: Serialize,
1321 {
1322 self.conn
1323 .send(Message::Text(
1324 serde_json::to_string(&response::Wrapper::new_with_id(id, Some(res), None))
1325 .unwrap(),
1326 ))
1327 .await
1328 .unwrap();
1329 }
1330
1331 async fn terminate(&mut self) {
1332 let _ = self
1333 .conn
1334 .close(Some(CloseFrame {
1335 code: CloseCode::Normal,
1336 reason: Default::default(),
1337 }))
1338 .await;
1339 }
1340 }
1341
1342 async fn read_json_fixture(version: &str, name: &str) -> String {
1343 fs::read_to_string(
1344 PathBuf::from("./tests/kvstore_fixtures")
1345 .join(version)
1346 .join("incoming")
1347 .join(name.to_owned() + ".json"),
1348 )
1349 .await
1350 .unwrap()
1351 }
1352
1353 mod v0_34 {
1354 use super::*;
1355 use crate::event::v0_34::DeEvent;
1356
1357 async fn read_event(name: &str) -> Event {
1358 DeEvent::from_string(read_json_fixture("v0_34", name).await)
1359 .unwrap()
1360 .into()
1361 }
1362
1363 #[tokio::test]
1364 async fn websocket_client_happy_path() {
1365 let event1 = read_event("subscribe_newblock_0").await;
1366 let event2 = read_event("subscribe_newblock_1").await;
1367 let event3 = read_event("subscribe_newblock_2").await;
1368 let test_events = vec![event1, event2, event3];
1369
1370 println!("Starting WebSocket server...");
1371 let mut server = TestServer::new("127.0.0.1:0", TestRpcVersion::V0_34).await;
1372 println!("Creating client RPC WebSocket connection...");
1373 let url = server.node_addr.clone().try_into().unwrap();
1374 let (client, driver) = WebSocketClient::builder(url)
1375 .compat_mode(CompatMode::V0_34)
1376 .build()
1377 .await
1378 .unwrap();
1379 let driver_handle = tokio::spawn(async move { driver.run().await });
1380
1381 println!("Initiating subscription for new blocks...");
1382 let mut subs = client.subscribe(EventType::NewBlock.into()).await.unwrap();
1383
1384 let subs_collector_hdl = tokio::spawn(async move {
1386 let mut results = Vec::new();
1387 while let Some(res) = subs.next().await {
1388 results.push(res);
1389 if results.len() == 3 {
1390 break;
1391 }
1392 }
1393 results
1394 });
1395
1396 println!("Publishing events");
1397 for ev in &test_events {
1399 server.publish_event(ev.clone()).unwrap();
1400 }
1401
1402 println!("Collecting results from subscription...");
1403 let collected_results = subs_collector_hdl.await.unwrap();
1404
1405 client.close().unwrap();
1406 server.terminate().await.unwrap();
1407 let _ = driver_handle.await.unwrap();
1408 println!("Closed client and terminated server");
1409
1410 assert_eq!(3, collected_results.len());
1411 for i in 0..3 {
1412 assert_eq!(
1413 test_events[i],
1414 collected_results[i].as_ref().unwrap().clone()
1415 );
1416 }
1417 }
1418 }
1419
1420 mod v0_37 {
1421 use super::*;
1422 use crate::event::v0_37::DeEvent;
1423
1424 async fn read_event(name: &str) -> Event {
1425 DeEvent::from_string(read_json_fixture("v0_37", name).await)
1426 .unwrap()
1427 .into()
1428 }
1429
1430 #[tokio::test]
1431 async fn websocket_client_happy_path() {
1432 let event1 = read_event("subscribe_newblock_0").await;
1433 let event2 = read_event("subscribe_newblock_1").await;
1434 let event3 = read_event("subscribe_newblock_2").await;
1435 let test_events = vec![event1, event2, event3];
1436
1437 println!("Starting WebSocket server...");
1438 let mut server = TestServer::new("127.0.0.1:0", TestRpcVersion::V0_37).await;
1439 println!("Creating client RPC WebSocket connection...");
1440 let url = server.node_addr.clone().try_into().unwrap();
1441 let (client, driver) = WebSocketClient::builder(url)
1442 .compat_mode(CompatMode::V0_37)
1443 .build()
1444 .await
1445 .unwrap();
1446 let driver_handle = tokio::spawn(async move { driver.run().await });
1447
1448 println!("Initiating subscription for new blocks...");
1449 let mut subs = client.subscribe(EventType::NewBlock.into()).await.unwrap();
1450
1451 let subs_collector_hdl = tokio::spawn(async move {
1453 let mut results = Vec::new();
1454 while let Some(res) = subs.next().await {
1455 results.push(res);
1456 if results.len() == 3 {
1457 break;
1458 }
1459 }
1460 results
1461 });
1462
1463 println!("Publishing events");
1464 for ev in &test_events {
1466 server.publish_event(ev.clone()).unwrap();
1467 }
1468
1469 println!("Collecting results from subscription...");
1470 let collected_results = subs_collector_hdl.await.unwrap();
1471
1472 client.close().unwrap();
1473 server.terminate().await.unwrap();
1474 let _ = driver_handle.await.unwrap();
1475 println!("Closed client and terminated server");
1476
1477 assert_eq!(3, collected_results.len());
1478 for i in 0..3 {
1479 assert_eq!(
1480 test_events[i],
1481 collected_results[i].as_ref().unwrap().clone()
1482 );
1483 }
1484 }
1485 }
1486
1487 mod v0_38 {
1488 use super::*;
1489 use crate::event::v0_38::DeEvent;
1490
1491 async fn read_event(name: &str) -> Event {
1492 DeEvent::from_string(read_json_fixture("v0_38", name).await)
1493 .unwrap()
1494 .into()
1495 }
1496
1497 #[tokio::test]
1498 async fn websocket_client_happy_path() {
1499 let event1 = read_event("subscribe_newblock_0").await;
1500 let event2 = read_event("subscribe_newblock_1").await;
1501 let event3 = read_event("subscribe_newblock_2").await;
1502 let test_events = vec![event1, event2, event3];
1503
1504 println!("Starting WebSocket server...");
1505 let mut server = TestServer::new("127.0.0.1:0", TestRpcVersion::V0_38).await;
1506 println!("Creating client RPC WebSocket connection...");
1507 let url = server.node_addr.clone().try_into().unwrap();
1508 let (client, driver) = WebSocketClient::builder(url)
1509 .compat_mode(CompatMode::V0_37)
1510 .build()
1511 .await
1512 .unwrap();
1513 let driver_handle = tokio::spawn(async move { driver.run().await });
1514
1515 println!("Initiating subscription for new blocks...");
1516 let mut subs = client.subscribe(EventType::NewBlock.into()).await.unwrap();
1517
1518 let subs_collector_hdl = tokio::spawn(async move {
1520 let mut results = Vec::new();
1521 while let Some(res) = subs.next().await {
1522 results.push(res);
1523 if results.len() == 3 {
1524 break;
1525 }
1526 }
1527 results
1528 });
1529
1530 println!("Publishing events");
1531 for ev in &test_events {
1533 server.publish_event(ev.clone()).unwrap();
1534 }
1535
1536 println!("Collecting results from subscription...");
1537 let collected_results = subs_collector_hdl.await.unwrap();
1538
1539 client.close().unwrap();
1540 server.terminate().await.unwrap();
1541 let _ = driver_handle.await.unwrap();
1542 println!("Closed client and terminated server");
1543
1544 assert_eq!(3, collected_results.len());
1545 for i in 0..3 {
1546 assert_eq!(
1547 test_events[i],
1548 collected_results[i].as_ref().unwrap().clone()
1549 );
1550 }
1551 }
1552 }
1553
1554 fn authorization(req: &http::Request<()>) -> Option<&str> {
1555 req.headers()
1556 .get(AUTHORIZATION)
1557 .map(|h| h.to_str().unwrap())
1558 }
1559
/// A URI without credentials yields a client request with no
/// `Authorization` header.
#[test]
fn without_basic_auth() {
    let req = Uri::from_str("http://example.com")
        .unwrap()
        .into_client_request()
        .unwrap();

    assert_eq!(authorization(&req), None);
}
1567
/// Converts a URI carrying `user:password` credentials with
/// `into_client_request` and checks the resulting `Authorization` header.
#[test]
fn with_basic_auth() {
    let req = Uri::from_str("http://toto:tata@example.com")
        .unwrap()
        .into_client_request()
        .unwrap();

    // NOTE(review): despite the test's name, this asserts that NO
    // `Authorization` header is produced for a URI with embedded
    // credentials. Confirm this is the intended contract.
    assert_eq!(authorization(&req), None);
}
1574 }
1575}