alloy_transport_ws/
native.rs1use crate::WsBackend;
2use alloy_pubsub::PubSubConnect;
3use alloy_transport::{utils::Spawnable, Authorization, TransportErrorKind, TransportResult};
4use futures::{SinkExt, StreamExt};
5use serde_json::value::RawValue;
6use std::time::Duration;
7use tokio::time::sleep;
8use tokio_tungstenite::{
9 tungstenite::{self, client::IntoClientRequest, Message},
10 MaybeTlsStream, WebSocketStream,
11};
12
13type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
14
15pub use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
16
17const KEEPALIVE: u64 = 10;
18
19#[derive(Clone, Debug)]
21pub struct WsConnect {
22 url: String,
24 auth: Option<Authorization>,
26 config: Option<WebSocketConfig>,
28 max_retries: u32,
31 retry_interval: Duration,
34}
35
36impl WsConnect {
37 pub fn new<S: Into<String>>(url: S) -> Self {
39 Self {
40 url: url.into(),
41 auth: None,
42 config: None,
43 max_retries: 10,
44 retry_interval: Duration::from_secs(3),
45 }
46 }
47
48 pub fn with_auth(mut self, auth: Authorization) -> Self {
50 self.auth = Some(auth);
51 self
52 }
53
54 pub fn with_auth_opt(mut self, auth: Option<Authorization>) -> Self {
58 self.auth = auth;
59 self
60 }
61
62 pub const fn with_config(mut self, config: WebSocketConfig) -> Self {
64 self.config = Some(config);
65 self
66 }
67
68 pub fn url(&self) -> &str {
70 &self.url
71 }
72
73 pub const fn auth(&self) -> Option<&Authorization> {
75 self.auth.as_ref()
76 }
77
78 pub const fn config(&self) -> Option<&WebSocketConfig> {
80 self.config.as_ref()
81 }
82
83 pub const fn with_max_retries(mut self, max_retries: u32) -> Self {
86 self.max_retries = max_retries;
87 self
88 }
89
90 pub const fn with_retry_interval(mut self, retry_interval: Duration) -> Self {
93 self.retry_interval = retry_interval;
94 self
95 }
96}
97
98impl IntoClientRequest for WsConnect {
99 fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
100 let mut request: http::Request<()> = self.url.into_client_request()?;
101 if let Some(auth) = self.auth {
102 let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
103 auth_value.set_sensitive(true);
104
105 request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
106 }
107
108 request.into_client_request()
109 }
110}
111
112impl PubSubConnect for WsConnect {
113 fn is_local(&self) -> bool {
114 alloy_transport::utils::guess_local_url(&self.url)
115 }
116
117 async fn connect(&self) -> TransportResult<alloy_pubsub::ConnectionHandle> {
118 let request = self.clone().into_client_request();
119 let req = request.map_err(TransportErrorKind::custom)?;
120 let (socket, _) = tokio_tungstenite::connect_async_with_config(req, self.config, false)
121 .await
122 .map_err(TransportErrorKind::custom)?;
123
124 let (handle, interface) = alloy_pubsub::ConnectionHandle::new();
125 let backend = WsBackend { socket, interface };
126
127 backend.spawn();
128
129 Ok(handle.with_max_retries(self.max_retries).with_retry_interval(self.retry_interval))
130 }
131}
132
133impl WsBackend<TungsteniteStream> {
134 #[expect(clippy::result_unit_err)]
136 pub fn handle(&mut self, msg: Message) -> Result<(), ()> {
137 match msg {
138 Message::Text(text) => self.handle_text(&text),
139 Message::Close(frame) => {
140 if frame.is_some() {
141 error!(?frame, "Received close frame with data");
142 } else {
143 error!("WS server has gone away");
144 }
145 Err(())
146 }
147 Message::Binary(_) => {
148 error!("Received binary message, expected text");
149 Err(())
150 }
151 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(()),
152 }
153 }
154
155 pub async fn send(&mut self, msg: Box<RawValue>) -> Result<(), tungstenite::Error> {
157 self.socket.send(Message::Text(msg.get().to_owned().into())).await
158 }
159
160 pub fn spawn(mut self) {
162 let fut = async move {
163 let mut errored = false;
164 let mut expecting_pong = false;
165 let keepalive = sleep(Duration::from_secs(KEEPALIVE));
166 tokio::pin!(keepalive);
167 loop {
168 tokio::select! {
178 biased;
179 inst = self.interface.recv_from_frontend() => {
183 match inst {
184 Some(msg) => {
185 keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
187 if let Err(err) = self.send(msg).await {
188 error!(%err, "WS connection error");
189 errored = true;
190 break
191 }
192 },
193 None => {
195 break
196 },
197 }
198 },
199 _ = &mut keepalive => {
202 if expecting_pong {
205 error!("WS server missed a pong");
206 errored = true;
207 break
208 }
209 keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
211 if let Err(err) = self.socket.send(Message::Ping(Default::default())).await {
212 error!(%err, "WS connection error");
213 errored = true;
214 break
215 }
216 expecting_pong = true;
219 }
220 resp = self.socket.next() => {
221 match resp {
222 Some(Ok(item)) => {
223 if item.is_pong() {
224 expecting_pong = false;
225 }
226 errored = self.handle(item).is_err();
227 if errored { break }
228 },
229 Some(Err(err)) => {
230 error!(%err, "WS connection error");
231 errored = true;
232 break
233 }
234 None => {
235 error!("WS server has gone away");
236 errored = true;
237 break
238 },
239 }
240 }
241 }
242 }
243 if errored {
244 self.interface.close_with_error();
245 }
246 };
247 fut.spawn_task()
248 }
249}