1use crate::{
2 errors::ProviderError,
3 rpc::transports::common::{JsonRpcError, Params, Request, Response},
4 JsonRpcClient, PubsubClient,
5};
6
7use async_trait::async_trait;
8use ethers_core::types::U256;
9use futures_channel::{mpsc, oneshot};
10use futures_util::{
11 sink::{Sink, SinkExt},
12 stream::{Fuse, Stream, StreamExt},
13};
14use serde::{de::DeserializeOwned, Serialize};
15use serde_json::value::RawValue;
16use std::{
17 collections::{btree_map::Entry, BTreeMap},
18 fmt::{self, Debug},
19 sync::{
20 atomic::{AtomicU64, Ordering},
21 Arc,
22 },
23};
24use thiserror::Error;
25use tracing::trace;
26
/// Re-emits each item it is given, gating every one of them behind
/// `#[cfg(target_arch = "wasm32")]` so they only exist in wasm builds.
macro_rules! if_wasm {
    ( $( $it:item )* ) => {
        $(
            #[cfg(target_arch = "wasm32")]
            $it
        )*
    };
}
33
/// Re-emits each item it is given, gating every one of them behind
/// `#[cfg(not(target_arch = "wasm32"))]` so they only exist in native builds.
macro_rules! if_not_wasm {
    ( $( $it:item )* ) => {
        $(
            #[cfg(not(target_arch = "wasm32"))]
            $it
        )*
    };
}
40
// wasm32-only imports, transport type aliases and console-logging shims.
if_wasm! {
    use wasm_bindgen::prelude::*;
    use wasm_bindgen_futures::spawn_local;
    use ws_stream_wasm::*;

    // Websocket message type used by the wasm transport.
    type Message = WsMessage;
    // Transport-level error type of the wasm transport.
    type WsError = ws_stream_wasm::WsErr;
    // On wasm the socket stream yields messages directly (no `Result` wrapper).
    type WsStreamItem = Message;

    // Logging shims: each formats its arguments like `format!` and forwards the
    // resulting string to `web_sys::console` at the matching level, standing in
    // for the `tracing` macros used on native targets.
    macro_rules! error {
        ( $( $t:tt )* ) => {
            web_sys::console::error_1(&format!( $( $t )* ).into());
        }
    }
    macro_rules! warn {
        ( $( $t:tt )* ) => {
            web_sys::console::warn_1(&format!( $( $t )* ).into());
        }
    }
    macro_rules! debug {
        ( $( $t:tt )* ) => {
            web_sys::console::log_1(&format!( $( $t )* ).into());
        }
    }
}
66
// Native-only imports and transport type aliases (tokio-tungstenite backend).
if_not_wasm! {
    use tokio_tungstenite::{
        connect_async,
        tungstenite::{
            self,
            protocol::CloseFrame,
        },
    };
    // Websocket message type used by the native transport.
    type Message = tungstenite::protocol::Message;
    // Transport-level error type of the native transport.
    type WsError = tungstenite::Error;
    // The native socket stream yields `Result`s, so read errors arrive in-band.
    type WsStreamItem = Result<Message, WsError>;
    use super::Authorization;
    // Native builds log through `tracing` instead of the console shims.
    use tracing::{debug, error, warn};
    use http::Request as HttpRequest;
    use tungstenite::client::IntoClientRequest;
}
83
/// One-shot channel on which the result (or JSON-RPC error) of a single request
/// is delivered back to the caller.
type Pending = oneshot::Sender<Result<Box<RawValue>, JsonRpcError>>;
/// Channel on which raw notification payloads for one subscription are forwarded.
type Subscription = mpsc::UnboundedSender<Box<RawValue>>;
86
/// Commands sent from `Ws` handles to the background `WsServer` task.
enum Instruction {
    /// Send the serialized JSON-RPC `request` (with JSON-RPC id `id`) and deliver
    /// its response on `sender`.
    Request { id: u64, request: String, sender: Pending },
    /// Start forwarding notifications for subscription `id` into `sink`.
    Subscribe { id: U256, sink: Subscription },
    /// Stop forwarding notifications for subscription `id`. This only drops the
    /// local routing entry; no message is sent over the socket by the server.
    Unsubscribe { id: U256 },
}
96
/// A JSON-RPC client that talks to a node over a websocket.
///
/// The socket itself is owned by a background `WsServer` task; this handle only
/// holds a request-id counter and a channel to that task. Clones share both.
#[derive(Clone)]
pub struct Ws {
    /// Counter used to assign JSON-RPC request ids (initialised to 1 in `new`).
    id: Arc<AtomicU64>,
    /// Sender half of the channel that feeds the background server task.
    instructions: mpsc::UnboundedSender<Instruction>,
}
114
115impl Debug for Ws {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 f.debug_struct("WebsocketProvider").field("id", &self.id).finish()
118 }
119}
120
121impl Ws {
122 pub fn new<S: 'static>(ws: S) -> Self
125 where
126 S: Send + Sync + Stream<Item = WsStreamItem> + Sink<Message, Error = WsError> + Unpin,
127 {
128 let (sink, stream) = mpsc::unbounded();
129 WsServer::new(ws, stream).spawn();
131
132 Self { id: Arc::new(AtomicU64::new(1)), instructions: sink }
133 }
134
135 pub fn ready(&self) -> bool {
137 !self.instructions.is_closed()
138 }
139
140 #[cfg(target_arch = "wasm32")]
142 pub async fn connect(url: &str) -> Result<Self, ClientError> {
143 let (_, wsio) = WsMeta::connect(url, None).await.expect_throw("Could not create websocket");
144
145 Ok(Self::new(wsio))
146 }
147
148 #[cfg(not(target_arch = "wasm32"))]
150 pub async fn connect(url: impl IntoClientRequest + Unpin) -> Result<Self, ClientError> {
151 let (ws, _) = connect_async(url).await?;
152 Ok(Self::new(ws))
153 }
154
155 #[cfg(not(target_arch = "wasm32"))]
157 pub async fn connect_with_auth(
158 uri: impl IntoClientRequest + Unpin,
159 auth: Authorization,
160 ) -> Result<Self, ClientError> {
161 let mut request: HttpRequest<()> = uri.into_client_request()?;
162
163 let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
164 auth_value.set_sensitive(true);
165
166 request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
167 Self::connect(request).await
168 }
169
170 fn send(&self, msg: Instruction) -> Result<(), ClientError> {
171 self.instructions.unbounded_send(msg).map_err(to_client_error)
172 }
173}
174
175#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
176#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
177impl JsonRpcClient for Ws {
178 type Error = ClientError;
179
180 async fn request<T: Serialize + Send + Sync, R: DeserializeOwned>(
181 &self,
182 method: &str,
183 params: T,
184 ) -> Result<R, ClientError> {
185 let next_id = self.id.fetch_add(1, Ordering::SeqCst);
186
187 let (sender, receiver) = oneshot::channel();
189 let payload = Instruction::Request {
190 id: next_id,
191 request: serde_json::to_string(&Request::new(next_id, method, params))?,
192 sender,
193 };
194
195 self.send(payload)?;
197
198 let res = receiver.await??;
200
201 Ok(serde_json::from_str(res.get())?)
203 }
204}
205
206impl PubsubClient for Ws {
207 type NotificationStream = mpsc::UnboundedReceiver<Box<RawValue>>;
208
209 fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> {
210 let (sink, stream) = mpsc::unbounded();
211 self.send(Instruction::Subscribe { id: id.into(), sink })?;
212 Ok(stream)
213 }
214
215 fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> {
216 self.send(Instruction::Unsubscribe { id: id.into() })
217 }
218}
219
/// State of the background task that owns the websocket and routes requests,
/// responses and subscription notifications.
struct WsServer<S> {
    /// The underlying websocket, fused so it can be polled after it terminates.
    ws: Fuse<S>,
    /// Instructions coming from `Ws` handles.
    instructions: Fuse<mpsc::UnboundedReceiver<Instruction>>,

    /// Requests awaiting a response, keyed by JSON-RPC request id.
    pending: BTreeMap<u64, Pending>,
    /// Active subscription routes, keyed by subscription id.
    subscriptions: BTreeMap<U256, Subscription>,
}
227
228impl<S> WsServer<S>
229where
230 S: Send + Sync + Stream<Item = WsStreamItem> + Sink<Message, Error = WsError> + Unpin,
231{
232 fn new(ws: S, requests: mpsc::UnboundedReceiver<Instruction>) -> Self {
234 Self {
235 ws: ws.fuse(),
238 instructions: requests.fuse(),
239 pending: BTreeMap::default(),
240 subscriptions: BTreeMap::default(),
241 }
242 }
243
244 fn is_done(&self) -> bool {
249 self.instructions.is_done() && self.pending.is_empty() && self.subscriptions.is_empty()
250 }
251
252 fn spawn(mut self)
254 where
255 S: 'static,
256 {
257 let f = async move {
258 loop {
259 if self.is_done() {
260 debug!("work complete");
261 break
262 }
263
264 if let Err(e) = self.tick().await {
265 error!("Received a WebSocket error: {:?}", e);
266 self.close_all_subscriptions();
267 break
268 }
269 }
270 };
271
272 #[cfg(target_arch = "wasm32")]
273 spawn_local(f);
274
275 #[cfg(not(target_arch = "wasm32"))]
276 tokio::spawn(f);
277 }
278
279 fn close_all_subscriptions(&self) {
282 error!("Tearing down subscriptions");
283 for (_, sub) in self.subscriptions.iter() {
284 sub.close_channel();
285 }
286 }
287
288 async fn service_request(
290 &mut self,
291 id: u64,
292 request: String,
293 sender: Pending,
294 ) -> Result<(), ClientError> {
295 if self.pending.insert(id, sender).is_some() {
296 warn!("Replacing a pending request with id {:?}", id);
297 }
298
299 if let Err(e) = self.ws.send(Message::Text(request)).await {
300 error!("WS connection error: {:?}", e);
301 self.pending.remove(&id);
302 }
303 Ok(())
304 }
305
306 async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> {
308 if self.subscriptions.insert(id, sink).is_some() {
309 warn!("Replacing already-registered subscription with id {:?}", id);
310 }
311 Ok(())
312 }
313
314 async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> {
316 if self.subscriptions.remove(&id).is_none() {
317 warn!("Unsubscribing from non-existent subscription with id {:?}", id);
318 }
319 Ok(())
320 }
321
322 async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> {
324 match instruction {
325 Instruction::Request { id, request, sender } => {
326 self.service_request(id, request, sender).await
327 }
328 Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await,
329 Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await,
330 }
331 }
332
333 #[cfg(not(target_arch = "wasm32"))]
334 async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
335 self.ws.send(Message::Pong(inner)).await?;
336 Ok(())
337 }
338
339 async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
340 trace!(msg=?inner, "received message");
341 let (id, result) = match serde_json::from_str(&inner)? {
342 Response::Success { id, result } => (id, Ok(result.to_owned())),
343 Response::Error { id, error } => (id, Err(error)),
344 Response::Notification { params, .. } => return self.handle_notification(params),
345 };
346
347 if let Some(request) = self.pending.remove(&id) {
348 if !request.is_canceled() {
349 request.send(result).map_err(to_client_error)?;
350 }
351 }
352
353 Ok(())
354 }
355
356 fn handle_notification(&mut self, params: Params<'_>) -> Result<(), ClientError> {
357 let id = params.subscription;
358 if let Entry::Occupied(stream) = self.subscriptions.entry(id) {
359 if let Err(err) = stream.get().unbounded_send(params.result.to_owned()) {
360 if err.is_disconnected() {
361 stream.remove();
363 }
364 return Err(to_client_error(err))
365 }
366 }
367
368 Ok(())
369 }
370
371 #[cfg(target_arch = "wasm32")]
372 async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
373 match resp {
374 Message::Text(inner) => self.handle_text(inner).await,
375 Message::Binary(buf) => Err(ClientError::UnexpectedBinary(buf)),
376 }
377 }
378
379 #[cfg(not(target_arch = "wasm32"))]
380 async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
381 match resp {
382 Message::Text(inner) => self.handle_text(inner).await,
383 Message::Frame(_) => Ok(()), Message::Ping(inner) => self.handle_ping(inner).await,
385 Message::Pong(_) => Ok(()), Message::Close(Some(frame)) => Err(ClientError::WsClosed(frame)),
387 Message::Close(None) => Err(ClientError::UnexpectedClose),
388 Message::Binary(buf) => Err(ClientError::UnexpectedBinary(buf)),
389 }
390 }
391
392 #[allow(clippy::single_match)]
394 #[cfg(target_arch = "wasm32")]
395 async fn tick(&mut self) -> Result<(), ClientError> {
396 futures_util::select! {
397 instruction = self.instructions.select_next_some() => {
399 self.service(instruction).await?;
400 },
401 resp = self.ws.next() => match resp {
403 Some(resp) => self.handle(resp).await?,
404 None => {
405 return Err(ClientError::UnexpectedClose);
406 },
407 }
408 };
409
410 Ok(())
411 }
412
413 #[allow(clippy::single_match)]
415 #[cfg(not(target_arch = "wasm32"))]
416 async fn tick(&mut self) -> Result<(), ClientError> {
417 futures_util::select! {
418 instruction = self.instructions.select_next_some() => {
420 self.service(instruction).await?;
421 },
422 resp = self.ws.next() => match resp {
424 Some(Ok(resp)) => self.handle(resp).await?,
425 Some(Err(err)) => {
426 tracing::error!(?err);
427 return Err(ClientError::UnexpectedClose);
428 }
429 None => {
430 return Err(ClientError::UnexpectedClose);
431 },
432 }
433 };
434
435 Ok(())
436 }
437}
438
439fn to_client_error<T: Debug>(err: T) -> ClientError {
441 ClientError::ChannelError(format!("{err:?}"))
442}
443
/// Errors produced by the websocket client.
#[derive(Debug, Error)]
pub enum ClientError {
    /// JSON (de)serialization of a request, response or result failed.
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),

    /// The node answered a request with a JSON-RPC error object.
    #[error(transparent)]
    JsonRpcError(#[from] JsonRpcError),

    /// A binary frame was received; only text frames are expected.
    #[error("Websocket responded with unexpected binary data")]
    UnexpectedBinary(Vec<u8>),

    /// Transport-level websocket error (`tungstenite::Error` natively,
    /// `ws_stream_wasm::WsErr` on wasm32).
    #[error(transparent)]
    TungsteniteError(#[from] WsError),

    /// An internal channel operation failed; holds the `Debug` rendering of the
    /// underlying error.
    #[error("{0}")]
    ChannelError(String),

    /// The response channel for a request was dropped before a response arrived.
    #[error("{0}")]
    Canceled(#[from] oneshot::Canceled),

    /// The server sent a close frame (native builds).
    #[error("Websocket closed with info: {0:?}")]
    #[cfg(not(target_arch = "wasm32"))]
    WsClosed(CloseFrame<'static>),

    /// The websocket was closed (wasm32 builds). Not constructed in this file.
    #[error("Websocket closed")]
    #[cfg(target_arch = "wasm32")]
    WsClosed,

    /// The socket ended, failed to read, or sent a close without a frame.
    #[error("WebSocket connection closed unexpectedly")]
    UnexpectedClose,

    /// The `Authorization` value could not be encoded as an HTTP header value.
    #[error(transparent)]
    #[cfg(not(target_arch = "wasm32"))]
    WsAuth(#[from] http::header::InvalidHeaderValue),

    /// Conversion from an invalid URI.
    #[error(transparent)]
    #[cfg(not(target_arch = "wasm32"))]
    UriError(#[from] http::uri::InvalidUri),

    /// Conversion from an `http::Error` raised while building a request.
    #[error(transparent)]
    #[cfg(not(target_arch = "wasm32"))]
    RequestError(#[from] http::Error),
}
500
501impl crate::RpcError for ClientError {
502 fn as_error_response(&self) -> Option<&super::JsonRpcError> {
503 if let ClientError::JsonRpcError(err) = self {
504 Some(err)
505 } else {
506 None
507 }
508 }
509
510 fn as_serde_error(&self) -> Option<&serde_json::Error> {
511 match self {
512 ClientError::JsonError(err) => Some(err),
513 _ => None,
514 }
515 }
516}
517
518impl From<ClientError> for ProviderError {
519 fn from(src: ClientError) -> Self {
520 ProviderError::JsonRpcClientError(Box::new(src))
521 }
522}
523
// Integration tests against a locally spawned Anvil node (native builds only).
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use ethers_core::utils::Anvil;

    // Two requests separated by a few block times must observe an increasing
    // block number, exercising the request/response round trip.
    #[tokio::test]
    async fn request() {
        let anvil = Anvil::new().block_time(1u64).spawn();
        let ws = Ws::connect(anvil.ws_endpoint()).await.unwrap();

        let block_num: U256 = ws.request("eth_blockNumber", ()).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        let block_num2: U256 = ws.request("eth_blockNumber", ()).await.unwrap();
        assert!(block_num2 > block_num);
    }

    // Subscribes to `newHeads` and checks that the first three notifications
    // routed through the subscription stream are blocks 1, 2 and 3.
    #[tokio::test]
    #[cfg(not(feature = "celo"))]
    async fn subscription() {
        use ethers_core::types::{Block, TxHash};

        let anvil = Anvil::new().block_time(1u64).spawn();
        let ws = Ws::connect(anvil.ws_endpoint()).await.unwrap();

        let sub_id: U256 = ws.request("eth_subscribe", ["newHeads"]).await.unwrap();
        let stream = ws.subscribe(sub_id).unwrap();

        let blocks: Vec<u64> = stream
            .take(3)
            .map(|item| {
                let block: Block<TxHash> = serde_json::from_str(item.get()).unwrap();
                block.number.unwrap_or_default().as_u64()
            })
            .collect()
            .await;
        assert_eq!(blocks, vec![1, 2, 3]);
    }

    // A text frame that is not valid JSON must surface as an error from
    // `handle_text` rather than being silently ignored.
    #[tokio::test]
    async fn deserialization_fails() {
        let anvil = Anvil::new().block_time(1u64).spawn();
        let (ws, _) = tokio_tungstenite::connect_async(anvil.ws_endpoint()).await.unwrap();
        let malformed_data = String::from("not a valid message");
        let (_, stream) = mpsc::unbounded();
        let resp = WsServer::new(ws, stream).handle_text(malformed_data).await;
        resp.unwrap_err();
    }
}
574
575impl crate::Provider<Ws> {
576 #[cfg(not(target_arch = "wasm32"))]
578 pub async fn connect(
579 url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
580 ) -> Result<Self, ProviderError> {
581 let ws = crate::Ws::connect(url).await?;
582 Ok(Self::new(ws))
583 }
584
585 #[cfg(target_arch = "wasm32")]
587 pub async fn connect(url: &str) -> Result<Self, ProviderError> {
588 let ws = crate::Ws::connect(url).await?;
589 Ok(Self::new(ws))
590 }
591
592 #[cfg(not(target_arch = "wasm32"))]
594 pub async fn connect_with_auth(
595 url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
596 auth: Authorization,
597 ) -> Result<Self, ProviderError> {
598 let ws = crate::Ws::connect_with_auth(url, auth).await?;
599 Ok(Self::new(ws))
600 }
601}