// ethers_providers/rpc/transports/legacy_ws.rs

1use crate::{
2    errors::ProviderError,
3    rpc::transports::common::{JsonRpcError, Params, Request, Response},
4    JsonRpcClient, PubsubClient,
5};
6
7use async_trait::async_trait;
8use ethers_core::types::U256;
9use futures_channel::{mpsc, oneshot};
10use futures_util::{
11    sink::{Sink, SinkExt},
12    stream::{Fuse, Stream, StreamExt},
13};
14use serde::{de::DeserializeOwned, Serialize};
15use serde_json::value::RawValue;
16use std::{
17    collections::{btree_map::Entry, BTreeMap},
18    fmt::{self, Debug},
19    sync::{
20        atomic::{AtomicU64, Ordering},
21        Arc,
22    },
23};
24use thiserror::Error;
25use tracing::trace;
26
// Emits each item it is given only when compiling for `wasm32`.
//
// Every item receives its own `#[cfg(target_arch = "wasm32")]` attribute, so the
// macro can wrap `use` declarations, type aliases and macro definitions alike.
macro_rules! if_wasm {
    ($($wasm_only:item)*) => {
        $(
            #[cfg(target_arch = "wasm32")]
            $wasm_only
        )*
    };
}
33
// Emits each item it is given only when compiling for a non-`wasm32` target.
//
// The native counterpart of `if_wasm!`: each item is individually tagged with
// `#[cfg(not(target_arch = "wasm32"))]`.
macro_rules! if_not_wasm {
    ($($native_only:item)*) => {
        $(
            #[cfg(not(target_arch = "wasm32"))]
            $native_only
        )*
    };
}
40
// wasm32-only dependencies and aliases. `Message`, `WsError` and `WsStreamItem`
// name the socket's message, error and stream-item types so that the rest of the
// file can be written once for both targets.
if_wasm! {
    use wasm_bindgen::prelude::*;
    use wasm_bindgen_futures::spawn_local;
    use ws_stream_wasm::*;

    type Message = WsMessage;
    type WsError = ws_stream_wasm::WsErr;
    // On wasm32 the socket stream yields messages directly, not `Result`s.
    type WsStreamItem = Message;

    // Browser-console stand-ins for the `tracing` macros imported on native targets,
    // so that call sites such as `error!(...)` compile unchanged on wasm32.
    macro_rules! error {
        ( $( $t:tt )* ) => {
            web_sys::console::error_1(&format!( $( $t )* ).into());
        }
    }
    macro_rules! warn {
        ( $( $t:tt )* ) => {
            web_sys::console::warn_1(&format!( $( $t )* ).into());
        }
    }
    macro_rules! debug {
        ( $( $t:tt )* ) => {
            web_sys::console::log_1(&format!( $( $t )* ).into());
        }
    }
}
66
// Native-only dependencies and aliases. The socket comes from `tokio-tungstenite`,
// and its stream yields `Result<Message, tungstenite::Error>` items.
if_not_wasm! {
    use tokio_tungstenite::{
        connect_async,
        tungstenite::{
            self,
            protocol::CloseFrame,
        },
    };
    type Message = tungstenite::protocol::Message;
    type WsError = tungstenite::Error;
    type WsStreamItem = Result<Message, WsError>;
    use super::Authorization;
    // The real `tracing` macros; the wasm32 build defines console-based equivalents.
    use tracing::{debug, error, warn};
    use http::Request as HttpRequest;
    use tungstenite::client::IntoClientRequest;
}
83
/// Completion handle for one in-flight request. It receives either the raw `result`
/// payload of the response or the JSON-RPC error object the server returned.
type Pending = oneshot::Sender<Result<Box<RawValue>, JsonRpcError>>;
/// Channel that carries raw notification payloads to a single subscription's listener.
type Subscription = mpsc::UnboundedSender<Box<RawValue>>;
86
/// Instructions for the `WsServer`.
///
/// `Ws` handles send these over an unbounded channel. The background server task
/// handles them one per `tick`.
enum Instruction {
    /// JSON-RPC request.
    ///
    /// `id` is the JSON-RPC id used to match the response, `request` is the
    /// already-serialized request body, and `sender` receives the outcome.
    Request { id: u64, request: String, sender: Pending },
    /// Create a new subscription: notifications for subscription `id` are forwarded to `sink`.
    Subscribe { id: U256, sink: Subscription },
    /// Cancel an existing subscription by removing its local route.
    Unsubscribe { id: U256 },
}
96
/// A JSON-RPC Client over Websockets.
///
/// A `Ws` is a handle to a background task that owns the socket. Cloning it shares
/// the request-id counter and the instruction channel with the original.
///
/// # Example
///
/// ```no_run
/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
/// use ethers_providers::Ws;
///
/// let ws = Ws::connect("ws://localhost:8545").await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct Ws {
    /// Incrementing source of JSON-RPC request ids, shared between clones.
    id: Arc<AtomicU64>,
    /// Channel to the background `WsServer` task.
    instructions: mpsc::UnboundedSender<Instruction>,
}
114
115impl Debug for Ws {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        f.debug_struct("WebsocketProvider").field("id", &self.id).finish()
118    }
119}
120
121impl Ws {
122    /// Initializes a new WebSocket Client, given a Stream/Sink Websocket implementer.
123    /// The websocket connection must be initiated separately.
124    pub fn new<S: 'static>(ws: S) -> Self
125    where
126        S: Send + Sync + Stream<Item = WsStreamItem> + Sink<Message, Error = WsError> + Unpin,
127    {
128        let (sink, stream) = mpsc::unbounded();
129        // Spawn the server
130        WsServer::new(ws, stream).spawn();
131
132        Self { id: Arc::new(AtomicU64::new(1)), instructions: sink }
133    }
134
135    /// Returns true if the WS connection is active, false otherwise
136    pub fn ready(&self) -> bool {
137        !self.instructions.is_closed()
138    }
139
140    /// Initializes a new WebSocket Client
141    #[cfg(target_arch = "wasm32")]
142    pub async fn connect(url: &str) -> Result<Self, ClientError> {
143        let (_, wsio) = WsMeta::connect(url, None).await.expect_throw("Could not create websocket");
144
145        Ok(Self::new(wsio))
146    }
147
148    /// Initializes a new WebSocket Client
149    #[cfg(not(target_arch = "wasm32"))]
150    pub async fn connect(url: impl IntoClientRequest + Unpin) -> Result<Self, ClientError> {
151        let (ws, _) = connect_async(url).await?;
152        Ok(Self::new(ws))
153    }
154
155    /// Initializes a new WebSocket Client with authentication
156    #[cfg(not(target_arch = "wasm32"))]
157    pub async fn connect_with_auth(
158        uri: impl IntoClientRequest + Unpin,
159        auth: Authorization,
160    ) -> Result<Self, ClientError> {
161        let mut request: HttpRequest<()> = uri.into_client_request()?;
162
163        let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
164        auth_value.set_sensitive(true);
165
166        request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
167        Self::connect(request).await
168    }
169
170    fn send(&self, msg: Instruction) -> Result<(), ClientError> {
171        self.instructions.unbounded_send(msg).map_err(to_client_error)
172    }
173}
174
175#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
176#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
177impl JsonRpcClient for Ws {
178    type Error = ClientError;
179
180    async fn request<T: Serialize + Send + Sync, R: DeserializeOwned>(
181        &self,
182        method: &str,
183        params: T,
184    ) -> Result<R, ClientError> {
185        let next_id = self.id.fetch_add(1, Ordering::SeqCst);
186
187        // send the message
188        let (sender, receiver) = oneshot::channel();
189        let payload = Instruction::Request {
190            id: next_id,
191            request: serde_json::to_string(&Request::new(next_id, method, params))?,
192            sender,
193        };
194
195        // send the data
196        self.send(payload)?;
197
198        // wait for the response (the request itself may have errors as well)
199        let res = receiver.await??;
200
201        // parse it
202        Ok(serde_json::from_str(res.get())?)
203    }
204}
205
206impl PubsubClient for Ws {
207    type NotificationStream = mpsc::UnboundedReceiver<Box<RawValue>>;
208
209    fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> {
210        let (sink, stream) = mpsc::unbounded();
211        self.send(Instruction::Subscribe { id: id.into(), sink })?;
212        Ok(stream)
213    }
214
215    fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> {
216        self.send(Instruction::Unsubscribe { id: id.into() })
217    }
218}
219
/// State of the background task that owns the websocket and serves all `Ws` handles
/// connected to it.
struct WsServer<S> {
    /// The underlying websocket, fused so that it can be used in `select!`.
    ws: Fuse<S>,
    /// Instructions (requests, subscribe/unsubscribe) sent by `Ws` handles.
    instructions: Fuse<mpsc::UnboundedReceiver<Instruction>>,

    /// In-flight requests keyed by JSON-RPC id, each awaiting its response.
    pending: BTreeMap<u64, Pending>,
    /// Registered subscription listeners keyed by subscription id.
    subscriptions: BTreeMap<U256, Subscription>,
}
227
228impl<S> WsServer<S>
229where
230    S: Send + Sync + Stream<Item = WsStreamItem> + Sink<Message, Error = WsError> + Unpin,
231{
232    /// Instantiates the Websocket Server
233    fn new(ws: S, requests: mpsc::UnboundedReceiver<Instruction>) -> Self {
234        Self {
235            // Fuse the 2 steams together, so that we can `select` them in the
236            // Stream implementation
237            ws: ws.fuse(),
238            instructions: requests.fuse(),
239            pending: BTreeMap::default(),
240            subscriptions: BTreeMap::default(),
241        }
242    }
243
244    /// Returns whether the all work has been completed.
245    ///
246    /// If this method returns `true`, then the `instructions` channel has been closed and all
247    /// pending requests and subscriptions have been completed.
248    fn is_done(&self) -> bool {
249        self.instructions.is_done() && self.pending.is_empty() && self.subscriptions.is_empty()
250    }
251
252    /// Spawns the event loop
253    fn spawn(mut self)
254    where
255        S: 'static,
256    {
257        let f = async move {
258            loop {
259                if self.is_done() {
260                    debug!("work complete");
261                    break
262                }
263
264                if let Err(e) = self.tick().await {
265                    error!("Received a WebSocket error: {:?}", e);
266                    self.close_all_subscriptions();
267                    break
268                }
269            }
270        };
271
272        #[cfg(target_arch = "wasm32")]
273        spawn_local(f);
274
275        #[cfg(not(target_arch = "wasm32"))]
276        tokio::spawn(f);
277    }
278
279    // This will close all active subscriptions. Each process listening for
280    // updates will observe the end of their subscription streams.
281    fn close_all_subscriptions(&self) {
282        error!("Tearing down subscriptions");
283        for (_, sub) in self.subscriptions.iter() {
284            sub.close_channel();
285        }
286    }
287
288    // dispatch an RPC request
289    async fn service_request(
290        &mut self,
291        id: u64,
292        request: String,
293        sender: Pending,
294    ) -> Result<(), ClientError> {
295        if self.pending.insert(id, sender).is_some() {
296            warn!("Replacing a pending request with id {:?}", id);
297        }
298
299        if let Err(e) = self.ws.send(Message::Text(request)).await {
300            error!("WS connection error: {:?}", e);
301            self.pending.remove(&id);
302        }
303        Ok(())
304    }
305
306    /// Dispatch a subscription request
307    async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> {
308        if self.subscriptions.insert(id, sink).is_some() {
309            warn!("Replacing already-registered subscription with id {:?}", id);
310        }
311        Ok(())
312    }
313
314    /// Dispatch a unsubscribe request
315    async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> {
316        if self.subscriptions.remove(&id).is_none() {
317            warn!("Unsubscribing from non-existent subscription with id {:?}", id);
318        }
319        Ok(())
320    }
321
322    /// Dispatch an outgoing message
323    async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> {
324        match instruction {
325            Instruction::Request { id, request, sender } => {
326                self.service_request(id, request, sender).await
327            }
328            Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await,
329            Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await,
330        }
331    }
332
333    #[cfg(not(target_arch = "wasm32"))]
334    async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
335        self.ws.send(Message::Pong(inner)).await?;
336        Ok(())
337    }
338
339    async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
340        trace!(msg=?inner, "received message");
341        let (id, result) = match serde_json::from_str(&inner)? {
342            Response::Success { id, result } => (id, Ok(result.to_owned())),
343            Response::Error { id, error } => (id, Err(error)),
344            Response::Notification { params, .. } => return self.handle_notification(params),
345        };
346
347        if let Some(request) = self.pending.remove(&id) {
348            if !request.is_canceled() {
349                request.send(result).map_err(to_client_error)?;
350            }
351        }
352
353        Ok(())
354    }
355
356    fn handle_notification(&mut self, params: Params<'_>) -> Result<(), ClientError> {
357        let id = params.subscription;
358        if let Entry::Occupied(stream) = self.subscriptions.entry(id) {
359            if let Err(err) = stream.get().unbounded_send(params.result.to_owned()) {
360                if err.is_disconnected() {
361                    // subscription channel was closed on the receiver end
362                    stream.remove();
363                }
364                return Err(to_client_error(err))
365            }
366        }
367
368        Ok(())
369    }
370
371    #[cfg(target_arch = "wasm32")]
372    async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
373        match resp {
374            Message::Text(inner) => self.handle_text(inner).await,
375            Message::Binary(buf) => Err(ClientError::UnexpectedBinary(buf)),
376        }
377    }
378
379    #[cfg(not(target_arch = "wasm32"))]
380    async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
381        match resp {
382            Message::Text(inner) => self.handle_text(inner).await,
383            Message::Frame(_) => Ok(()), // Server is allowed to send Raw frames
384            Message::Ping(inner) => self.handle_ping(inner).await,
385            Message::Pong(_) => Ok(()), // Server is allowed to send unsolicited pongs.
386            Message::Close(Some(frame)) => Err(ClientError::WsClosed(frame)),
387            Message::Close(None) => Err(ClientError::UnexpectedClose),
388            Message::Binary(buf) => Err(ClientError::UnexpectedBinary(buf)),
389        }
390    }
391
392    /// Processes 1 instruction or 1 incoming websocket message
393    #[allow(clippy::single_match)]
394    #[cfg(target_arch = "wasm32")]
395    async fn tick(&mut self) -> Result<(), ClientError> {
396        futures_util::select! {
397            // Handle requests
398            instruction = self.instructions.select_next_some() => {
399                self.service(instruction).await?;
400            },
401            // Handle ws messages
402            resp = self.ws.next() => match resp {
403                Some(resp) => self.handle(resp).await?,
404                None => {
405                    return Err(ClientError::UnexpectedClose);
406                },
407            }
408        };
409
410        Ok(())
411    }
412
413    /// Processes 1 instruction or 1 incoming websocket message
414    #[allow(clippy::single_match)]
415    #[cfg(not(target_arch = "wasm32"))]
416    async fn tick(&mut self) -> Result<(), ClientError> {
417        futures_util::select! {
418            // Handle requests
419            instruction = self.instructions.select_next_some() => {
420                self.service(instruction).await?;
421            },
422            // Handle ws messages
423            resp = self.ws.next() => match resp {
424                Some(Ok(resp)) => self.handle(resp).await?,
425                Some(Err(err)) => {
426                    tracing::error!(?err);
427                    return Err(ClientError::UnexpectedClose);
428                }
429                None => {
430                    return Err(ClientError::UnexpectedClose);
431                },
432            }
433        };
434
435        Ok(())
436    }
437}
438
439// TrySendError is private :(
440fn to_client_error<T: Debug>(err: T) -> ClientError {
441    ClientError::ChannelError(format!("{err:?}"))
442}
443
/// Error thrown when sending a WS message
///
/// Covers failures from several layers: (de)serialization, JSON-RPC error responses,
/// the websocket transport, the internal channels, and (on native targets) building
/// the handshake request.
#[derive(Debug, Error)]
pub enum ClientError {
    /// Thrown if serializing a request or deserializing a response failed
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),

    /// Thrown if the server answered a request with a JSON-RPC error object
    #[error(transparent)]
    JsonRpcError(#[from] JsonRpcError),

    /// Thrown if the websocket responds with binary data
    #[error("Websocket responded with unexpected binary data")]
    UnexpectedBinary(Vec<u8>),

    /// Thrown if there's an error over the WS connection.
    ///
    /// Despite the name, on wasm32 this wraps `ws_stream_wasm::WsErr`.
    #[error(transparent)]
    TungsteniteError(#[from] WsError),

    /// Error in an internal channel, carrying the `Debug` text of the underlying error
    #[error("{0}")]
    ChannelError(String),

    /// Error in internal oneshot channel: the response sender was dropped before replying
    #[error("{0}")]
    Canceled(#[from] oneshot::Canceled),

    /// Remote server sent a Close message
    #[error("Websocket closed with info: {0:?}")]
    #[cfg(not(target_arch = "wasm32"))]
    WsClosed(CloseFrame<'static>),

    /// Remote server sent a Close message
    #[error("Websocket closed")]
    #[cfg(target_arch = "wasm32")]
    WsClosed,

    /// Something caused the websocket to close
    #[error("WebSocket connection closed unexpectedly")]
    UnexpectedClose,

    /// Could not create an auth header for websocket handshake
    #[error(transparent)]
    #[cfg(not(target_arch = "wasm32"))]
    WsAuth(#[from] http::header::InvalidHeaderValue),

    /// Unable to create a valid Uri
    #[error(transparent)]
    #[cfg(not(target_arch = "wasm32"))]
    UriError(#[from] http::uri::InvalidUri),

    /// Unable to create a valid Request
    #[error(transparent)]
    #[cfg(not(target_arch = "wasm32"))]
    RequestError(#[from] http::Error),
}
500
501impl crate::RpcError for ClientError {
502    fn as_error_response(&self) -> Option<&super::JsonRpcError> {
503        if let ClientError::JsonRpcError(err) = self {
504            Some(err)
505        } else {
506            None
507        }
508    }
509
510    fn as_serde_error(&self) -> Option<&serde_json::Error> {
511        match self {
512            ClientError::JsonError(err) => Some(err),
513            _ => None,
514        }
515    }
516}
517
518impl From<ClientError> for ProviderError {
519    fn from(src: ClientError) -> Self {
520        ProviderError::JsonRpcClientError(Box::new(src))
521    }
522}
523
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use ethers_core::utils::Anvil;

    /// Two `eth_blockNumber` calls several block-times apart must observe chain progress.
    #[tokio::test]
    async fn request() {
        let anvil = Anvil::new().block_time(1u64).spawn();
        let client = Ws::connect(anvil.ws_endpoint()).await.unwrap();

        let before: U256 = client.request("eth_blockNumber", ()).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        let after: U256 = client.request("eth_blockNumber", ()).await.unwrap();

        assert!(after > before);
    }

    /// `newHeads` notifications reach the stream registered for their subscription id.
    #[tokio::test]
    #[cfg(not(feature = "celo"))]
    async fn subscription() {
        use ethers_core::types::{Block, TxHash};

        let anvil = Anvil::new().block_time(1u64).spawn();
        let client = Ws::connect(anvil.ws_endpoint()).await.unwrap();

        // Creating the subscription on the node and listening to it locally are two
        // steps: `eth_subscribe` returns the id, then `subscribe` registers a route for it.
        let sub_id: U256 = client.request("eth_subscribe", ["newHeads"]).await.unwrap();
        let notifications = client.subscribe(sub_id).unwrap();

        let numbers: Vec<u64> = notifications
            .take(3)
            .map(|raw| {
                let header: Block<TxHash> = serde_json::from_str(raw.get()).unwrap();
                header.number.unwrap_or_default().as_u64()
            })
            .collect()
            .await;

        assert_eq!(numbers, vec![1, 2, 3]);
    }

    /// A frame that is not valid JSON-RPC makes `handle_text` return an error.
    #[tokio::test]
    async fn deserialization_fails() {
        let anvil = Anvil::new().block_time(1u64).spawn();
        let (socket, _) = tokio_tungstenite::connect_async(anvil.ws_endpoint()).await.unwrap();
        let (_, instructions) = mpsc::unbounded();
        let mut server = WsServer::new(socket, instructions);

        let outcome = server.handle_text(String::from("not a valid message")).await;
        assert!(outcome.is_err());
    }
}
574
575impl crate::Provider<Ws> {
576    /// Direct connection to a websocket endpoint
577    #[cfg(not(target_arch = "wasm32"))]
578    pub async fn connect(
579        url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
580    ) -> Result<Self, ProviderError> {
581        let ws = crate::Ws::connect(url).await?;
582        Ok(Self::new(ws))
583    }
584
585    /// Direct connection to a websocket endpoint
586    #[cfg(target_arch = "wasm32")]
587    pub async fn connect(url: &str) -> Result<Self, ProviderError> {
588        let ws = crate::Ws::connect(url).await?;
589        Ok(Self::new(ws))
590    }
591
592    /// Connect to a WS RPC provider with authentication details
593    #[cfg(not(target_arch = "wasm32"))]
594    pub async fn connect_with_auth(
595        url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
596        auth: Authorization,
597    ) -> Result<Self, ProviderError> {
598        let ws = crate::Ws::connect_with_auth(url, auth).await?;
599        Ok(Self::new(ws))
600    }
601}