alloy_transport_ws/
native.rs

1use crate::WsBackend;
2use alloy_pubsub::PubSubConnect;
3use alloy_transport::{utils::Spawnable, Authorization, TransportErrorKind, TransportResult};
4use futures::{SinkExt, StreamExt};
5use serde_json::value::RawValue;
6use std::time::Duration;
7use tokio::time::sleep;
8use tokio_tungstenite::{
9    tungstenite::{self, client::IntoClientRequest, Message},
10    MaybeTlsStream, WebSocketStream,
11};
12
/// Stream type returned by `tokio_tungstenite::connect_async_with_config` over a TCP socket that
/// may or may not be TLS-wrapped; the native [`WsBackend`] methods are implemented for this type.
type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

pub use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;

/// Keepalive interval, in seconds.
///
/// The backend task sends a ping when this much time passes without outgoing traffic, and treats
/// a ping that is still unanswered when the keepalive timer next fires as a dead connection.
const KEEPALIVE: u64 = 10;
18
/// Simple connection details for a websocket connection.
///
/// Construct with [`WsConnect::new`] and refine with the `with_*` builder methods. This type
/// implements [`PubSubConnect`]; its `connect` opens the socket and spawns the backend task.
///
/// NOTE(review): the derived `Debug` prints `auth` through `Authorization`'s own `Debug` impl —
/// confirm that impl does not expose credentials.
#[derive(Clone, Debug)]
pub struct WsConnect {
    /// The URL to connect to.
    url: String,
    /// The authorization header to use.
    ///
    /// When set, it is rendered into the `Authorization` header of the upgrade request and
    /// marked as sensitive.
    auth: Option<Authorization>,
    /// The websocket config.
    ///
    /// Passed unchanged to `tokio_tungstenite::connect_async_with_config`.
    config: Option<WebSocketConfig>,
    /// Max number of retries before failing and exiting the connection.
    /// Default is 10.
    max_retries: u32,
    /// The interval between retries.
    /// Default is 3 seconds.
    retry_interval: Duration,
}
35
36impl WsConnect {
37    /// Creates a new websocket connection configuration.
38    pub fn new<S: Into<String>>(url: S) -> Self {
39        Self {
40            url: url.into(),
41            auth: None,
42            config: None,
43            max_retries: 10,
44            retry_interval: Duration::from_secs(3),
45        }
46    }
47
48    /// Sets the authorization header.
49    pub fn with_auth(mut self, auth: Authorization) -> Self {
50        self.auth = Some(auth);
51        self
52    }
53
54    /// Sets the optional authorization header.
55    ///
56    /// This replaces the current [`Authorization`].
57    pub fn with_auth_opt(mut self, auth: Option<Authorization>) -> Self {
58        self.auth = auth;
59        self
60    }
61
62    /// Sets the websocket config.
63    pub const fn with_config(mut self, config: WebSocketConfig) -> Self {
64        self.config = Some(config);
65        self
66    }
67
68    /// Get the URL string of the connection.
69    pub fn url(&self) -> &str {
70        &self.url
71    }
72
73    /// Get the authorization header.
74    pub const fn auth(&self) -> Option<&Authorization> {
75        self.auth.as_ref()
76    }
77
78    /// Get the websocket config.
79    pub const fn config(&self) -> Option<&WebSocketConfig> {
80        self.config.as_ref()
81    }
82
83    /// Sets the max number of retries before failing and exiting the connection.
84    /// Default is 10.
85    pub const fn with_max_retries(mut self, max_retries: u32) -> Self {
86        self.max_retries = max_retries;
87        self
88    }
89
90    /// Sets the interval between retries.
91    /// Default is 3 seconds.
92    pub const fn with_retry_interval(mut self, retry_interval: Duration) -> Self {
93        self.retry_interval = retry_interval;
94        self
95    }
96}
97
98impl IntoClientRequest for WsConnect {
99    fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
100        let mut request: http::Request<()> = self.url.into_client_request()?;
101        if let Some(auth) = self.auth {
102            let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
103            auth_value.set_sensitive(true);
104
105            request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
106        }
107
108        request.into_client_request()
109    }
110}
111
112impl PubSubConnect for WsConnect {
113    fn is_local(&self) -> bool {
114        alloy_transport::utils::guess_local_url(&self.url)
115    }
116
117    async fn connect(&self) -> TransportResult<alloy_pubsub::ConnectionHandle> {
118        let request = self.clone().into_client_request();
119        let req = request.map_err(TransportErrorKind::custom)?;
120        let (socket, _) = tokio_tungstenite::connect_async_with_config(req, self.config, false)
121            .await
122            .map_err(TransportErrorKind::custom)?;
123
124        let (handle, interface) = alloy_pubsub::ConnectionHandle::new();
125        let backend = WsBackend { socket, interface };
126
127        backend.spawn();
128
129        Ok(handle.with_max_retries(self.max_retries).with_retry_interval(self.retry_interval))
130    }
131}
132
133impl WsBackend<TungsteniteStream> {
134    /// Handle a message from the server.
135    #[expect(clippy::result_unit_err)]
136    pub fn handle(&mut self, msg: Message) -> Result<(), ()> {
137        match msg {
138            Message::Text(text) => self.handle_text(&text),
139            Message::Close(frame) => {
140                if frame.is_some() {
141                    error!(?frame, "Received close frame with data");
142                } else {
143                    error!("WS server has gone away");
144                }
145                Err(())
146            }
147            Message::Binary(_) => {
148                error!("Received binary message, expected text");
149                Err(())
150            }
151            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(()),
152        }
153    }
154
155    /// Send a message to the server.
156    pub async fn send(&mut self, msg: Box<RawValue>) -> Result<(), tungstenite::Error> {
157        self.socket.send(Message::Text(msg.get().to_owned().into())).await
158    }
159
160    /// Spawn a new backend task.
161    pub fn spawn(mut self) {
162        let fut = async move {
163            let mut errored = false;
164            let mut expecting_pong = false;
165            let keepalive = sleep(Duration::from_secs(KEEPALIVE));
166            tokio::pin!(keepalive);
167            loop {
168                // We bias the loop as follows
169                // 1. New dispatch to server.
170                // 2. Keepalive.
171                // 3. Response or notification from server.
172                // This ensures that keepalive is sent only if no other messages
173                // have been sent in the last 10 seconds. And prioritizes new
174                // dispatches over responses from the server. This will fail if
175                // the client saturates the task with dispatches, but that's
176                // probably not a big deal.
177                tokio::select! {
178                    biased;
179                    // we've received a new dispatch, so we send it via
180                    // websocket. We handle new work before processing any
181                    // responses from the server.
182                    inst = self.interface.recv_from_frontend() => {
183                        match inst {
184                            Some(msg) => {
185                                // Reset the keepalive timer.
186                                keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
187                                if let Err(err) = self.send(msg).await {
188                                    error!(%err, "WS connection error");
189                                    errored = true;
190                                    break
191                                }
192                            },
193                            // dispatcher has gone away, or shutdown was received
194                            None => {
195                                break
196                            },
197                        }
198                    },
199                    // Send a ping to the server, if no other messages have been
200                    // sent in the last 10 seconds.
201                    _ = &mut keepalive => {
202                        // Still expecting a pong from the previous ping,
203                        // meaning connection is errored.
204                        if expecting_pong {
205                            error!("WS server missed a pong");
206                            errored = true;
207                            break
208                        }
209                        // Reset the keepalive timer.
210                        keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
211                        if let Err(err) = self.socket.send(Message::Ping(Default::default())).await {
212                            error!(%err, "WS connection error");
213                            errored = true;
214                            break
215                        }
216                        // Expecting to receive a pong before the next
217                        // keepalive timer resolves.
218                        expecting_pong = true;
219                    }
220                    resp = self.socket.next() => {
221                        match resp {
222                            Some(Ok(item)) => {
223                                if item.is_pong() {
224                                    expecting_pong = false;
225                                }
226                                errored = self.handle(item).is_err();
227                                if errored { break }
228                            },
229                            Some(Err(err)) => {
230                                error!(%err, "WS connection error");
231                                errored = true;
232                                break
233                            }
234                            None => {
235                                error!("WS server has gone away");
236                                errored = true;
237                                break
238                            },
239                        }
240                    }
241                }
242            }
243            if errored {
244                self.interface.close_with_error();
245            }
246        };
247        fut.spawn_task()
248    }
249}