// alloy_transport_ws/native.rs

1use crate::WsBackend;
2use alloy_pubsub::PubSubConnect;
3use alloy_transport::{utils::Spawnable, Authorization, TransportErrorKind, TransportResult};
4use futures::{SinkExt, StreamExt};
5use serde_json::value::RawValue;
6use std::time::Duration;
7pub use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
8use tokio_tungstenite::{
9    tungstenite::{self, client::IntoClientRequest, Message},
10    MaybeTlsStream, WebSocketStream,
11};
12
13#[cfg(target_arch = "wasm32")]
14use wasmtimer::tokio::sleep;
15
16#[cfg(not(target_arch = "wasm32"))]
17use tokio::time::sleep;
18
19type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
20
21const KEEPALIVE: u64 = 10;
22
23/// Simple connection details for a websocket connection.
24#[derive(Clone, Debug)]
25pub struct WsConnect {
26    /// The URL to connect to.
27    pub url: String,
28    /// The authorization header to use.
29    pub auth: Option<Authorization>,
30    /// The websocket config.
31    pub config: Option<WebSocketConfig>,
32}
33
34impl WsConnect {
35    /// Creates a new websocket connection configuration.
36    pub fn new<S: Into<String>>(url: S) -> Self {
37        Self { url: url.into(), auth: None, config: None }
38    }
39
40    /// Sets the authorization header.
41    pub fn with_auth(mut self, auth: Authorization) -> Self {
42        self.auth = Some(auth);
43        self
44    }
45
46    /// Sets the websocket config.
47    pub const fn with_config(mut self, config: WebSocketConfig) -> Self {
48        self.config = Some(config);
49        self
50    }
51}
52
53impl IntoClientRequest for WsConnect {
54    fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
55        let mut request: http::Request<()> = self.url.into_client_request()?;
56        if let Some(auth) = self.auth {
57            let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
58            auth_value.set_sensitive(true);
59
60            request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
61        }
62
63        request.into_client_request()
64    }
65}
66
67impl PubSubConnect for WsConnect {
68    fn is_local(&self) -> bool {
69        alloy_transport::utils::guess_local_url(&self.url)
70    }
71
72    async fn connect(&self) -> TransportResult<alloy_pubsub::ConnectionHandle> {
73        let request = self.clone().into_client_request();
74        let req = request.map_err(TransportErrorKind::custom)?;
75        let (socket, _) = tokio_tungstenite::connect_async_with_config(req, self.config, false)
76            .await
77            .map_err(TransportErrorKind::custom)?;
78
79        let (handle, interface) = alloy_pubsub::ConnectionHandle::new();
80        let backend = WsBackend { socket, interface };
81
82        backend.spawn();
83
84        Ok(handle)
85    }
86}
87
88impl WsBackend<TungsteniteStream> {
89    /// Handle a message from the server.
90    #[allow(clippy::result_unit_err)]
91    pub fn handle(&mut self, msg: Message) -> Result<(), ()> {
92        match msg {
93            Message::Text(text) => self.handle_text(&text),
94            Message::Close(frame) => {
95                if frame.is_some() {
96                    error!(?frame, "Received close frame with data");
97                } else {
98                    error!("WS server has gone away");
99                }
100                Err(())
101            }
102            Message::Binary(_) => {
103                error!("Received binary message, expected text");
104                Err(())
105            }
106            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(()),
107        }
108    }
109
110    /// Send a message to the server.
111    pub async fn send(&mut self, msg: Box<RawValue>) -> Result<(), tungstenite::Error> {
112        self.socket.send(Message::Text(msg.get().to_owned().into())).await
113    }
114
115    /// Spawn a new backend task.
116    pub fn spawn(mut self) {
117        let fut = async move {
118            let mut errored = false;
119            let mut expecting_pong = false;
120            let keepalive = sleep(Duration::from_secs(KEEPALIVE));
121            tokio::pin!(keepalive);
122            loop {
123                // We bias the loop as follows
124                // 1. New dispatch to server.
125                // 2. Keepalive.
126                // 3. Response or notification from server.
127                // This ensures that keepalive is sent only if no other messages
128                // have been sent in the last 10 seconds. And prioritizes new
129                // dispatches over responses from the server. This will fail if
130                // the client saturates the task with dispatches, but that's
131                // probably not a big deal.
132                tokio::select! {
133                    biased;
134                    // we've received a new dispatch, so we send it via
135                    // websocket. We handle new work before processing any
136                    // responses from the server.
137                    inst = self.interface.recv_from_frontend() => {
138                        match inst {
139                            Some(msg) => {
140                                // Reset the keepalive timer.
141                                keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
142                                if let Err(err) = self.send(msg).await {
143                                    error!(%err, "WS connection error");
144                                    errored = true;
145                                    break
146                                }
147                            },
148                            // dispatcher has gone away, or shutdown was received
149                            None => {
150                                break
151                            },
152                        }
153                    },
154                    // Send a ping to the server, if no other messages have been
155                    // sent in the last 10 seconds.
156                    _ = &mut keepalive => {
157                        // Still expecting a pong from the previous ping,
158                        // meaning connection is errored.
159                        if expecting_pong {
160                            error!("WS server missed a pong");
161                            errored = true;
162                            break
163                        }
164                        // Reset the keepalive timer.
165                        keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
166                        if let Err(err) = self.socket.send(Message::Ping(Default::default())).await {
167                            error!(%err, "WS connection error");
168                            errored = true;
169                            break
170                        }
171                        // Expecting to receive a pong before the next
172                        // keepalive timer resolves.
173                        expecting_pong = true;
174                    }
175                    resp = self.socket.next() => {
176                        match resp {
177                            Some(Ok(item)) => {
178                                if item.is_pong() {
179                                    expecting_pong = false;
180                                }
181                                errored = self.handle(item).is_err();
182                                if errored { break }
183                            },
184                            Some(Err(err)) => {
185                                error!(%err, "WS connection error");
186                                errored = true;
187                                break
188                            }
189                            None => {
190                                error!("WS server has gone away");
191                                errored = true;
192                                break
193                            },
194                        }
195                    }
196                }
197            }
198            if errored {
199                self.interface.close_with_error();
200            }
201        };
202        fut.spawn_task()
203    }
204}